From 53e7fa6eb41260eadf80e7346c45ac87e6141d83 Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Tue, 9 Jan 2024 18:08:56 +0100 Subject: [PATCH 01/28] mark production tests --- openml/datasets/dataset.py | 3 ++- openml/utils.py | 2 +- tests/test_datasets/test_dataset.py | 4 +++- tests/test_datasets/test_dataset_functions.py | 14 ++++++++++++-- .../test_evaluations/test_evaluation_functions.py | 10 ++++++++++ .../test_sklearn_extension.py | 3 +++ tests/test_flows/test_flow.py | 5 ++++- tests/test_flows/test_flow_functions.py | 10 ++++++++++ tests/test_openml/test_config.py | 5 +++++ tests/test_runs/test_run.py | 2 +- tests/test_runs/test_run_functions.py | 10 ++++++++++ tests/test_setups/test_setup_functions.py | 3 +++ tests/test_study/test_study_functions.py | 6 ++++++ tests/test_tasks/test_clustering_task.py | 4 ++++ tests/test_tasks/test_task_functions.py | 3 +++ tests/test_tasks/test_task_methods.py | 2 +- 16 files changed, 78 insertions(+), 8 deletions(-) diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py index b898a145d..04137a1a5 100644 --- a/openml/datasets/dataset.py +++ b/openml/datasets/dataset.py @@ -1090,7 +1090,8 @@ def _get_qualities_pickle_file(qualities_file: str) -> str: return qualities_file + ".pkl" -def _read_qualities(qualities_file: Path) -> dict[str, float]: +def _read_qualities(qualities_file: str | Path) -> dict[str, float]: + qualities_file = Path(qualities_file) qualities_pickle_file = Path(_get_qualities_pickle_file(str(qualities_file))) try: with qualities_pickle_file.open("rb") as fh_binary: diff --git a/openml/utils.py b/openml/utils.py index a3e11229e..63b5ac23e 100644 --- a/openml/utils.py +++ b/openml/utils.py @@ -313,7 +313,7 @@ def _list_all( # noqa: C901, PLR0912 # max number of results to be shown LIMIT = active_filters.pop("size", None) - if LIMIT is None or not isinstance(LIMIT, int) or not np.isinf(LIMIT): + if (LIMIT is not None) and (not isinstance(LIMIT, int)) and (not np.isinf(LIMIT)): raise ValueError(f"'limit' should be an integer or inf but got {LIMIT}") if LIMIT is not None and BATCH_SIZE_ORIG > LIMIT: diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index 977f68757..af0d521c4 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -16,6 +16,7 @@ from openml.testing import TestBase +@pytest.mark.production() class OpenMLDatasetTest(TestBase): _multiprocess_can_split_ = True @@ -317,7 +318,7 @@ def setUp(self): def test_tagging(self): # tags can be at most 64 alphanumeric (+ underscore) chars - unique_indicator = str(time()).replace('.', '') + unique_indicator = str(time()).replace(".", "") tag = f"test_tag_OpenMLDatasetTestOnTestServer_{unique_indicator}" datasets = openml.datasets.list_datasets(tag=tag, output_format="dataframe") assert datasets.empty @@ -329,6 +330,7 @@ def test_tagging(self): datasets = openml.datasets.list_datasets(tag=tag, output_format="dataframe") assert datasets.empty +@pytest.mark.production() class OpenMLDatasetTestSparse(TestBase): _multiprocess_can_split_ = True diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py index 0435c30ef..a51ccf7c9 100644 --- a/tests/test_datasets/test_dataset_functions.py +++ b/tests/test_datasets/test_dataset_functions.py @@ -132,6 +132,7 @@ def test_list_datasets_empty(self): ) assert datasets.empty + @pytest.mark.production() def test_check_datasets_active(self): # Have to test on live because there is no deactivated dataset on the 
test server. openml.config.server = self.production_server @@ -155,7 +156,7 @@ def test_illegal_character_tag(self): tag = "illegal_tag&" try: dataset.push_tag(tag) - assert False + raise AssertionError() except openml.exceptions.OpenMLServerException as e: assert e.code == 477 @@ -164,7 +165,7 @@ def test_illegal_length_tag(self): tag = "a" * 65 try: dataset.push_tag(tag) - assert False + raise AssertionError() except openml.exceptions.OpenMLServerException as e: assert e.code == 477 @@ -206,6 +207,7 @@ def _datasets_retrieved_successfully(self, dids, metadata_only=True): ), ) + @pytest.mark.production() def test__name_to_id_with_deactivated(self): """Check that an activated dataset is returned if an earlier deactivated one exists.""" openml.config.server = self.production_server @@ -213,16 +215,19 @@ def test__name_to_id_with_deactivated(self): assert openml.datasets.functions._name_to_id("anneal") == 2 openml.config.server = self.test_server + @pytest.mark.production() def test__name_to_id_with_multiple_active(self): """With multiple active datasets, retrieve the least recent active.""" openml.config.server = self.production_server assert openml.datasets.functions._name_to_id("iris") == 61 + @pytest.mark.production() def test__name_to_id_with_version(self): """With multiple active datasets, retrieve the least recent active.""" openml.config.server = self.production_server assert openml.datasets.functions._name_to_id("iris", version=3) == 969 + @pytest.mark.production() def test__name_to_id_with_multiple_active_error(self): """With multiple active datasets, retrieve the least recent active.""" openml.config.server = self.production_server @@ -283,6 +288,7 @@ def test_get_datasets_lazy(self): datasets[1].get_data() self._datasets_retrieved_successfully([1, 2], metadata_only=False) + @pytest.mark.production() def test_get_dataset_by_name(self): dataset = openml.datasets.get_dataset("anneal") assert type(dataset) == OpenMLDataset @@ -312,6 +318,7 @@ def test_get_dataset_uint8_dtype(self): df, _, _, _ = dataset.get_data() assert df["carbon"].dtype == "uint8" + @pytest.mark.production() def test_get_dataset(self): # This is the only non-lazy load to ensure default behaviour works. dataset = openml.datasets.get_dataset(1) @@ -326,6 +333,7 @@ def test_get_dataset(self): openml.config.server = self.production_server self.assertRaises(OpenMLPrivateDatasetError, openml.datasets.get_dataset, 45) + @pytest.mark.production() def test_get_dataset_lazy(self): dataset = openml.datasets.get_dataset(1, download_data=False) assert type(dataset) == OpenMLDataset @@ -1550,6 +1558,7 @@ def test_data_fork(self): data_id=999999, ) + @pytest.mark.production() def test_get_dataset_parquet(self): # Parquet functionality is disabled on the test server # There is no parquet-copy of the test server yet. 
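The `@pytest.mark.production()` marker used throughout this series only takes effect once it is registered with pytest and wired into collection. The sketch below shows one minimal way a `conftest.py` could register such a marker and deselect the marked tests by default; the `RUN_PRODUCTION_TESTS` environment variable and the skip reason are illustrative assumptions, not part of these patches.

```python
# Minimal sketch, assuming a custom "production" marker like the one added in this series.
# RUN_PRODUCTION_TESTS is a hypothetical opt-in switch, not an actual openml-python setting.
import os

import pytest


def pytest_configure(config):
    # Register the marker so pytest does not warn about an unknown mark.
    config.addinivalue_line(
        "markers", "production: test runs against the production OpenML server"
    )


def pytest_collection_modifyitems(config, items):
    # Skip production-server tests unless explicitly opted in.
    if os.environ.get("RUN_PRODUCTION_TESTS"):
        return
    skip_production = pytest.mark.skip(reason="production server tests are opt-in")
    for item in items:
        if "production" in item.keywords:
            item.add_marker(skip_production)
```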
@@ -1559,6 +1568,7 @@ def test_get_dataset_parquet(self): assert dataset.parquet_file is not None assert os.path.isfile(dataset.parquet_file) + @pytest.mark.production() def test_list_datasets_with_high_size_parameter(self): # Testing on prod since concurrent deletion of uploded datasets make the test fail openml.config.server = self.production_server diff --git a/tests/test_evaluations/test_evaluation_functions.py b/tests/test_evaluations/test_evaluation_functions.py index c9cccff30..7af01384f 100644 --- a/tests/test_evaluations/test_evaluation_functions.py +++ b/tests/test_evaluations/test_evaluation_functions.py @@ -51,6 +51,7 @@ def _check_list_evaluation_setups(self, **kwargs): self.assertSequenceEqual(sorted(list1), sorted(list2)) return evals_setups + @pytest.mark.production() def test_evaluation_list_filter_task(self): openml.config.server = self.production_server @@ -70,6 +71,7 @@ def test_evaluation_list_filter_task(self): assert evaluations[run_id].value is not None assert evaluations[run_id].values is None + @pytest.mark.production() def test_evaluation_list_filter_uploader_ID_16(self): openml.config.server = self.production_server @@ -84,6 +86,7 @@ def test_evaluation_list_filter_uploader_ID_16(self): assert len(evaluations) > 50 + @pytest.mark.production() def test_evaluation_list_filter_uploader_ID_10(self): openml.config.server = self.production_server @@ -102,6 +105,7 @@ def test_evaluation_list_filter_uploader_ID_10(self): assert evaluations[run_id].value is not None assert evaluations[run_id].values is None + @pytest.mark.production() def test_evaluation_list_filter_flow(self): openml.config.server = self.production_server @@ -121,6 +125,7 @@ def test_evaluation_list_filter_flow(self): assert evaluations[run_id].value is not None assert evaluations[run_id].values is None + @pytest.mark.production() def test_evaluation_list_filter_run(self): openml.config.server = self.production_server @@ -140,6 +145,7 @@ def test_evaluation_list_filter_run(self): assert evaluations[run_id].value is not None assert evaluations[run_id].values is None + @pytest.mark.production() def test_evaluation_list_limit(self): openml.config.server = self.production_server @@ -157,6 +163,7 @@ def test_list_evaluations_empty(self): assert isinstance(evaluations, dict) + @pytest.mark.production() def test_evaluation_list_per_fold(self): openml.config.server = self.production_server size = 1000 @@ -194,6 +201,7 @@ def test_evaluation_list_per_fold(self): assert evaluations[run_id].value is not None assert evaluations[run_id].values is None + @pytest.mark.production() def test_evaluation_list_sort(self): openml.config.server = self.production_server size = 10 @@ -230,6 +238,7 @@ def test_list_evaluation_measures(self): assert isinstance(measures, list) is True assert all(isinstance(s, str) for s in measures) is True + @pytest.mark.production() def test_list_evaluations_setups_filter_flow(self): openml.config.server = self.production_server flow_id = [405] @@ -248,6 +257,7 @@ def test_list_evaluations_setups_filter_flow(self): keys = list(evals["parameters"].values[0].keys()) assert all(elem in columns for elem in keys) + @pytest.mark.production() def test_list_evaluations_setups_filter_task(self): openml.config.server = self.production_server task_id = [6] diff --git a/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py b/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py index 44612ca61..4c7b0d60e 100644 --- 
a/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py +++ b/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py @@ -273,6 +273,7 @@ def test_serialize_model(self): self.assertDictEqual(structure, structure_fixture) @pytest.mark.sklearn() + @pytest.mark.production() def test_can_handle_flow(self): openml.config.server = self.production_server @@ -1942,6 +1943,7 @@ def predict_proba(*args, **kwargs): ) == X_test.shape[0] * len(task.class_labels) @pytest.mark.sklearn() + @pytest.mark.production() def test_run_model_on_fold_regression(self): # There aren't any regression tasks on the test server openml.config.server = self.production_server @@ -1992,6 +1994,7 @@ def test_run_model_on_fold_regression(self): ) @pytest.mark.sklearn() + @pytest.mark.production() def test_run_model_on_fold_clustering(self): # There aren't any regression tasks on the test server openml.config.server = self.production_server diff --git a/tests/test_flows/test_flow.py b/tests/test_flows/test_flow.py index 104131806..a20248ed4 100644 --- a/tests/test_flows/test_flow.py +++ b/tests/test_flows/test_flow.py @@ -42,6 +42,7 @@ def setUp(self): def tearDown(self): super().tearDown() + @pytest.mark.production() def test_get_flow(self): # We need to use the production server here because 4024 is not the # test server @@ -74,6 +75,7 @@ def test_get_flow(self): assert subflow_3.parameters["L"] == "-1" assert len(subflow_3.components) == 0 + @pytest.mark.production() def test_get_structure(self): # also responsible for testing: flow.get_subflow # We need to use the production server here because 4024 is not the @@ -103,7 +105,7 @@ def test_tagging(self): flow_id = flows["id"].iloc[0] flow = openml.flows.get_flow(flow_id) # tags can be at most 64 alphanumeric (+ underscore) chars - unique_indicator = str(time()).replace('.', '') + unique_indicator = str(time()).replace(".", "") tag = f"test_tag_TestFlow_{unique_indicator}" flows = openml.flows.list_flows(tag=tag, output_format="dataframe") assert len(flows) == 0 @@ -536,6 +538,7 @@ def test_extract_tags(self): tags = openml.utils.extract_xml_tags("oml:tag", flow_dict["oml:flow"]) assert tags == ["OpenmlWeka", "weka"] + @pytest.mark.production() def test_download_non_scikit_learn_flows(self): openml.config.server = self.production_server diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py index 014c0ac99..68d49eafa 100644 --- a/tests/test_flows/test_flow_functions.py +++ b/tests/test_flows/test_flow_functions.py @@ -44,6 +44,7 @@ def _check_flow(self, flow): ) assert ext_version_str_or_none + @pytest.mark.production() def test_list_flows(self): openml.config.server = self.production_server # We can only perform a smoke test here because we test on dynamic @@ -54,6 +55,7 @@ def test_list_flows(self): for flow in flows.to_dict(orient="index").values(): self._check_flow(flow) + @pytest.mark.production() def test_list_flows_output_format(self): openml.config.server = self.production_server # We can only perform a smoke test here because we test on dynamic @@ -62,11 +64,13 @@ def test_list_flows_output_format(self): assert isinstance(flows, pd.DataFrame) assert len(flows) >= 1500 + @pytest.mark.production() def test_list_flows_empty(self): openml.config.server = self.production_server flows = openml.flows.list_flows(tag="NoOneEverUsesThisTag123", output_format="dataframe") assert flows.empty + @pytest.mark.production() def test_list_flows_by_tag(self): openml.config.server = self.production_server flows = 
openml.flows.list_flows(tag="weka", output_format="dataframe") @@ -74,6 +78,7 @@ def test_list_flows_by_tag(self): for flow in flows.to_dict(orient="index").values(): self._check_flow(flow) + @pytest.mark.production() def test_list_flows_paginate(self): openml.config.server = self.production_server size = 10 @@ -297,6 +302,7 @@ def test_sklearn_to_flow_list_of_lists(self): assert server_flow.parameters["categories"] == "[[0, 1], [0, 1]]" assert server_flow.model.categories == flow.model.categories + @pytest.mark.production() def test_get_flow1(self): # Regression test for issue #305 # Basically, this checks that a flow without an external version can be loaded @@ -331,6 +337,7 @@ def test_get_flow_reinstantiate_model_no_extension(self): LooseVersion(sklearn.__version__) == "0.19.1", reason="Requires scikit-learn!=0.19.1, because target flow is from that version.", ) + @pytest.mark.production() def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception(self): openml.config.server = self.production_server flow = 8175 @@ -351,6 +358,7 @@ def test_get_flow_with_reinstantiate_strict_with_wrong_version_raises_exception( # Because scikit-learn dropped min_impurity_split hyperparameter in 1.0, # and the requested flow is from 1.0.0 exactly. ) + @pytest.mark.production() def test_get_flow_reinstantiate_flow_not_strict_post_1(self): openml.config.server = self.production_server flow = openml.flows.get_flow(flow_id=19190, reinstantiate=True, strict_version=False) @@ -364,6 +372,7 @@ def test_get_flow_reinstantiate_flow_not_strict_post_1(self): reason="Requires scikit-learn 0.23.2 or ~0.24.", # Because these still have min_impurity_split, but with new scikit-learn module structure." ) + @pytest.mark.production() def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self): openml.config.server = self.production_server flow = openml.flows.get_flow(flow_id=18587, reinstantiate=True, strict_version=False) @@ -375,6 +384,7 @@ def test_get_flow_reinstantiate_flow_not_strict_023_and_024(self): LooseVersion(sklearn.__version__) > "0.23", reason="Requires scikit-learn<=0.23, because the scikit-learn module structure changed.", ) + @pytest.mark.production() def test_get_flow_reinstantiate_flow_not_strict_pre_023(self): openml.config.server = self.production_server flow = openml.flows.get_flow(flow_id=8175, reinstantiate=True, strict_version=False) diff --git a/tests/test_openml/test_config.py b/tests/test_openml/test_config.py index 38bcde16d..4b2d931ee 100644 --- a/tests/test_openml/test_config.py +++ b/tests/test_openml/test_config.py @@ -5,6 +5,8 @@ import tempfile import unittest.mock +import pytest + import openml.config import openml.testing @@ -68,6 +70,7 @@ def test_setup_with_config(self): class TestConfigurationForExamples(openml.testing.TestBase): + @pytest.mark.production() def test_switch_to_example_configuration(self): """Verifies the test configuration is loaded properly.""" # Below is the default test key which would be used anyway, but just for clarity: @@ -79,6 +82,7 @@ def test_switch_to_example_configuration(self): assert openml.config.apikey == "c0c42819af31e706efe1f4b88c23c6c1" assert openml.config.server == self.test_server + @pytest.mark.production() def test_switch_from_example_configuration(self): """Verifies the previous configuration is loaded after stopping.""" # Below is the default test key which would be used anyway, but just for clarity: @@ -100,6 +104,7 @@ def test_example_configuration_stop_before_start(self): 
openml.config.stop_using_configuration_for_example, ) + @pytest.mark.production() def test_example_configuration_start_twice(self): """Checks that the original config can be returned to if `start..` is called twice.""" openml.config.apikey = "610344db6388d9ba34f6db45a3cf71de" diff --git a/tests/test_runs/test_run.py b/tests/test_runs/test_run.py index e40d33820..ce46b6548 100644 --- a/tests/test_runs/test_run.py +++ b/tests/test_runs/test_run.py @@ -31,7 +31,7 @@ def test_tagging(self): run_id = runs["run_id"].iloc[0] run = openml.runs.get_run(run_id) # tags can be at most 64 alphanumeric (+ underscore) chars - unique_indicator = str(time()).replace('.', '') + unique_indicator = str(time()).replace(".", "") tag = f"test_tag_TestRun_{unique_indicator}" runs = openml.runs.list_runs(tag=tag, output_format="dataframe") assert len(runs) == 0 diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py index d36935b17..4ce5a07f1 100644 --- a/tests/test_runs/test_run_functions.py +++ b/tests/test_runs/test_run_functions.py @@ -1083,6 +1083,7 @@ def test_local_run_metric_score(self): self._test_local_evaluations(run) + @pytest.mark.production() def test_online_run_metric_score(self): openml.config.server = self.production_server @@ -1389,6 +1390,7 @@ def test__create_trace_from_arff(self): trace_arff = arff.load(arff_file) OpenMLRunTrace.trace_from_arff(trace_arff) + @pytest.mark.production() def test_get_run(self): # this run is not available on test openml.config.server = self.production_server @@ -1424,6 +1426,7 @@ def _check_run(self, run): assert isinstance(run, dict) assert len(run) == 8, str(run) + @pytest.mark.production() def test_get_runs_list(self): # TODO: comes from live, no such lists on test openml.config.server = self.production_server @@ -1440,6 +1443,7 @@ def test_list_runs_output_format(self): runs = openml.runs.list_runs(size=1000, output_format="dataframe") assert isinstance(runs, pd.DataFrame) + @pytest.mark.production() def test_get_runs_list_by_task(self): # TODO: comes from live, no such lists on test openml.config.server = self.production_server @@ -1458,6 +1462,7 @@ def test_get_runs_list_by_task(self): assert run["task_id"] in task_ids self._check_run(run) + @pytest.mark.production() def test_get_runs_list_by_uploader(self): # TODO: comes from live, no such lists on test openml.config.server = self.production_server @@ -1479,6 +1484,7 @@ def test_get_runs_list_by_uploader(self): assert run["uploader"] in uploader_ids self._check_run(run) + @pytest.mark.production() def test_get_runs_list_by_flow(self): # TODO: comes from live, no such lists on test openml.config.server = self.production_server @@ -1497,6 +1503,7 @@ def test_get_runs_list_by_flow(self): assert run["flow_id"] in flow_ids self._check_run(run) + @pytest.mark.production() def test_get_runs_pagination(self): # TODO: comes from live, no such lists on test openml.config.server = self.production_server @@ -1514,6 +1521,7 @@ def test_get_runs_pagination(self): for run in runs.to_dict(orient="index").values(): assert run["uploader"] in uploader_ids + @pytest.mark.production() def test_get_runs_list_by_filters(self): # TODO: comes from live, no such lists on test openml.config.server = self.production_server @@ -1551,6 +1559,7 @@ def test_get_runs_list_by_filters(self): ) assert len(runs) == 2 + @pytest.mark.production() def test_get_runs_list_by_tag(self): # TODO: comes from live, no such lists on test # Unit test works on production server only @@ -1669,6 +1678,7 @@ def 
test_run_flow_on_task_downloaded_flow(self): TestBase._mark_entity_for_removal("run", run.run_id) TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], run.run_id)) + @pytest.mark.production() def test_format_prediction_non_supervised(self): # non-supervised tasks don't exist on the test server openml.config.server = self.production_server diff --git a/tests/test_setups/test_setup_functions.py b/tests/test_setups/test_setup_functions.py index 5b5023dc8..519009aee 100644 --- a/tests/test_setups/test_setup_functions.py +++ b/tests/test_setups/test_setup_functions.py @@ -132,6 +132,7 @@ def test_get_setup(self): else: assert len(current.parameters) == num_params[idx] + @pytest.mark.production() def test_setup_list_filter_flow(self): openml.config.server = self.production_server @@ -150,6 +151,7 @@ def test_list_setups_empty(self): assert isinstance(setups, dict) + @pytest.mark.production() def test_list_setups_output_format(self): openml.config.server = self.production_server flow_id = 6794 @@ -169,6 +171,7 @@ def test_list_setups_output_format(self): assert isinstance(setups[next(iter(setups.keys()))], Dict) assert len(setups) == 10 + @pytest.mark.production() def test_setuplist_offset(self): # TODO: remove after pull on live for better testing # openml.config.server = self.production_server diff --git a/tests/test_study/test_study_functions.py b/tests/test_study/test_study_functions.py index b66b3b1e7..721c81f9e 100644 --- a/tests/test_study/test_study_functions.py +++ b/tests/test_study/test_study_functions.py @@ -12,6 +12,7 @@ class TestStudyFunctions(TestBase): _multiprocess_can_split_ = True + @pytest.mark.production() def test_get_study_old(self): openml.config.server = self.production_server @@ -22,6 +23,7 @@ def test_get_study_old(self): assert len(study.setups) == 30 assert study.runs is None + @pytest.mark.production() def test_get_study_new(self): openml.config.server = self.production_server @@ -32,6 +34,7 @@ def test_get_study_new(self): assert len(study.setups) == 1253 assert len(study.runs) == 1693 + @pytest.mark.production() def test_get_openml100(self): openml.config.server = self.production_server @@ -41,6 +44,7 @@ def test_get_openml100(self): assert isinstance(study_2, openml.study.OpenMLBenchmarkSuite) assert study.study_id == study_2.study_id + @pytest.mark.production() def test_get_study_error(self): openml.config.server = self.production_server @@ -49,6 +53,7 @@ def test_get_study_error(self): ): openml.study.get_study(99) + @pytest.mark.production() def test_get_suite(self): openml.config.server = self.production_server @@ -59,6 +64,7 @@ def test_get_suite(self): assert study.runs is None assert study.setups is None + @pytest.mark.production() def test_get_suite_error(self): openml.config.server = self.production_server diff --git a/tests/test_tasks/test_clustering_task.py b/tests/test_tasks/test_clustering_task.py index 08cc1d451..bc59ad26c 100644 --- a/tests/test_tasks/test_clustering_task.py +++ b/tests/test_tasks/test_clustering_task.py @@ -1,6 +1,8 @@ # License: BSD 3-Clause from __future__ import annotations +import pytest + import openml from openml.exceptions import OpenMLServerException from openml.tasks import TaskType @@ -18,12 +20,14 @@ def setUp(self, n_levels: int = 1): self.task_type = TaskType.CLUSTERING self.estimation_procedure = 17 + @pytest.mark.production() def test_get_dataset(self): # no clustering tasks on test server openml.config.server = self.production_server task = openml.tasks.get_task(self.task_id) 
task.get_dataset() + @pytest.mark.production() def test_download_task(self): # no clustering tasks on test server openml.config.server = self.production_server diff --git a/tests/test_tasks/test_task_functions.py b/tests/test_tasks/test_task_functions.py index d651c2ad6..3dc776a2b 100644 --- a/tests/test_tasks/test_task_functions.py +++ b/tests/test_tasks/test_task_functions.py @@ -53,6 +53,7 @@ def test__get_estimation_procedure_list(self): assert isinstance(estimation_procedures[0], dict) assert estimation_procedures[0]["task_type_id"] == TaskType.SUPERVISED_CLASSIFICATION + @pytest.mark.production() def test_list_clustering_task(self): # as shown by #383, clustering tasks can give list/dict casting problems openml.config.server = self.production_server @@ -140,6 +141,7 @@ def test__get_task(self): @unittest.skip( "Please await outcome of discussion: https://github.com/openml/OpenML/issues/776", ) + @pytest.mark.production() def test__get_task_live(self): # Test the following task as it used to throw an Unicode Error. # https://github.com/openml/openml-python/issues/378 @@ -203,6 +205,7 @@ def test_get_task_with_cache(self): task = openml.tasks.get_task(1) assert isinstance(task, OpenMLTask) + @pytest.mark.production() def test_get_task_different_types(self): openml.config.server = self.production_server # Regression task diff --git a/tests/test_tasks/test_task_methods.py b/tests/test_tasks/test_task_methods.py index e9cfc5b58..552fbe949 100644 --- a/tests/test_tasks/test_task_methods.py +++ b/tests/test_tasks/test_task_methods.py @@ -18,7 +18,7 @@ def tearDown(self): def test_tagging(self): task = openml.tasks.get_task(1) # anneal; crossvalidation # tags can be at most 64 alphanumeric (+ underscore) chars - unique_indicator = str(time()).replace('.', '') + unique_indicator = str(time()).replace(".", "") tag = f"test_tag_OpenMLTaskMethodsTest_{unique_indicator}" tasks = openml.tasks.list_tasks(tag=tag, output_format="dataframe") assert len(tasks) == 0 From 5cf9f0f7bc5ddd1dbae8390421f74699c13cb08e Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Tue, 9 Jan 2024 20:15:45 +0100 Subject: [PATCH 02/28] make production test run --- openml/datasets/dataset.py | 5 ++-- openml/datasets/functions.py | 4 +-- openml/tasks/task.py | 3 +- tests/test_datasets/test_dataset_functions.py | 30 +++++++++---------- tests/test_flows/test_flow.py | 2 +- tests/test_setups/test_setup_functions.py | 4 --- 6 files changed, 21 insertions(+), 27 deletions(-) diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py index 04137a1a5..f81ddd23a 100644 --- a/openml/datasets/dataset.py +++ b/openml/datasets/dataset.py @@ -589,7 +589,6 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool] fpath = self.data_feather_file if self.cache_format == "feather" else self.data_pickle_file logger.info(f"{self.cache_format} load data {self.name}") try: - assert self.data_pickle_file is not None if self.cache_format == "feather": assert self.data_feather_file is not None assert self.feather_attribute_file is not None @@ -599,6 +598,7 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool] with open(self.feather_attribute_file, "rb") as fh: # noqa: PTH123 categorical, attribute_names = pickle.load(fh) # noqa: S301 else: + assert self.data_pickle_file is not None with open(self.data_pickle_file, "rb") as fh: # noqa: PTH123 data, categorical, attribute_names = pickle.load(fh) # noqa: S301 except FileNotFoundError as e: @@ -681,14 +681,13 @@ def 
_convert_array_format( if array_format == "array" and not isinstance(data, scipy.sparse.spmatrix): # We encode the categories such that they are integer to be able # to make a conversion to numeric for backward compatibility - def _encode_if_category(column: pd.Series) -> pd.Series: + def _encode_if_category(column: pd.Series | np.ndarray) -> pd.Series | np.ndarray: if column.dtype.name == "category": column = column.cat.codes.astype(np.float32) mask_nan = column == -1 column[mask_nan] = np.nan return column - assert isinstance(data, (pd.DataFrame, pd.Series)) if isinstance(data, pd.DataFrame): columns = { column_name: _encode_if_category(data.loc[:, column_name]) diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py index 099c7b257..7af0c858e 100644 --- a/openml/datasets/functions.py +++ b/openml/datasets/functions.py @@ -771,7 +771,7 @@ def create_dataset( # noqa: C901, PLR0912, PLR0915 if isinstance(data, pd.DataFrame): # infer the row id from the index of the dataset if row_id_attribute is None: - row_id_attribute = str(data.index.name) + row_id_attribute = data.index.name # When calling data.values, the index will be skipped. # We need to reset the index such that it is part of the data. if data.index.name is not None: @@ -1284,7 +1284,7 @@ def _get_dataset_arff( except OpenMLHashException as e: additional_info = f" Raised when downloading dataset {did}." e.args = (e.args[0] + additional_info,) - raise + raise e return output_file_path diff --git a/openml/tasks/task.py b/openml/tasks/task.py index 4d0b47cfb..d0a520042 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -338,8 +338,7 @@ def get_X_and_y( def _to_dict(self) -> dict[str, dict]: task_container = super()._to_dict() - task_dict = task_container["oml:task_inputs"] - oml_input = task_dict["oml:task_inputs"]["oml:input"] # type: ignore + oml_input = task_container["oml:task_inputs"]["oml:input"] # type: ignore assert isinstance(oml_input, list) oml_input.append({"@name": "target_feature", "#text": self.target_name}) diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py index a51ccf7c9..9fbb9259a 100644 --- a/tests/test_datasets/test_dataset_functions.py +++ b/tests/test_datasets/test_dataset_functions.py @@ -2,7 +2,7 @@ from __future__ import annotations import os -import pathlib +from pathlib import Path import random import shutil import time @@ -400,8 +400,8 @@ def test__getarff_path_dataset_arff(self): openml.config.set_root_cache_directory(self.static_cache_dir) description = _get_dataset_description(self.workdir, 2) arff_path = _get_dataset_arff(description, cache_directory=self.workdir) - assert isinstance(arff_path, str) - assert os.path.exists(arff_path) + assert isinstance(arff_path, Path) + assert arff_path.exists() def test__download_minio_file_object_does_not_exist(self): self.assertRaisesRegex( @@ -435,7 +435,7 @@ def test__download_minio_file_to_path(self): ), "_download_minio_file can save to a folder by copying the object name" def test__download_minio_file_raises_FileExists_if_destination_in_use(self): - file_destination = pathlib.Path(self.workdir, "custom.pq") + file_destination = Path(self.workdir, "custom.pq") file_destination.touch() self.assertRaises( @@ -447,7 +447,7 @@ def test__download_minio_file_raises_FileExists_if_destination_in_use(self): ) def test__download_minio_file_works_with_bucket_subdirectory(self): - file_destination = pathlib.Path(self.workdir, "custom.pq") + file_destination = Path(self.workdir, 
"custom.pq") _download_minio_file( source="http://openml1.win.tue.nl/dataset61/dataset_61.pq", destination=file_destination, @@ -463,8 +463,8 @@ def test__get_dataset_parquet_not_cached(self): "oml:id": "20", } path = _get_dataset_parquet(description, cache_directory=self.workdir) - assert isinstance(path, str), "_get_dataset_parquet returns a path" - assert os.path.isfile(path), "_get_dataset_parquet returns path to real file" + assert isinstance(path, Path), "_get_dataset_parquet returns a path" + assert path.is_file(), "_get_dataset_parquet returns path to real file" @mock.patch("openml._api_calls._download_minio_file") def test__get_dataset_parquet_is_cached(self, patch): @@ -477,8 +477,8 @@ def test__get_dataset_parquet_is_cached(self, patch): "oml:id": "30", } path = _get_dataset_parquet(description, cache_directory=None) - assert isinstance(path, str), "_get_dataset_parquet returns a path" - assert os.path.isfile(path), "_get_dataset_parquet returns path to real file" + assert isinstance(path, Path), "_get_dataset_parquet returns a path" + assert path.is_file(), "_get_dataset_parquet returns path to real file" def test__get_dataset_parquet_file_does_not_exist(self): description = { @@ -509,15 +509,15 @@ def test__getarff_md5_issue(self): def test__get_dataset_features(self): features_file = _get_dataset_features_file(self.workdir, 2) - assert isinstance(features_file, str) - features_xml_path = os.path.join(self.workdir, "features.xml") - assert os.path.exists(features_xml_path) + assert isinstance(features_file, Path) + features_xml_path = self.workdir / "features.xml" + assert features_xml_path.exists() def test__get_dataset_qualities(self): qualities = _get_dataset_qualities_file(self.workdir, 2) - assert isinstance(qualities, str) - qualities_xml_path = os.path.join(self.workdir, "qualities.xml") - assert os.path.exists(qualities_xml_path) + assert isinstance(qualities, Path) + qualities_xml_path = self.workdir / "qualities.xml" + assert qualities_xml_path.exists() def test__get_dataset_skip_download(self): dataset = openml.datasets.get_dataset( diff --git a/tests/test_flows/test_flow.py b/tests/test_flows/test_flow.py index a20248ed4..afa31ef63 100644 --- a/tests/test_flows/test_flow.py +++ b/tests/test_flows/test_flow.py @@ -105,7 +105,7 @@ def test_tagging(self): flow_id = flows["id"].iloc[0] flow = openml.flows.get_flow(flow_id) # tags can be at most 64 alphanumeric (+ underscore) chars - unique_indicator = str(time()).replace(".", "") + unique_indicator = str(time.time()).replace(".", "") tag = f"test_tag_TestFlow_{unique_indicator}" flows = openml.flows.list_flows(tag=tag, output_format="dataframe") assert len(flows) == 0 diff --git a/tests/test_setups/test_setup_functions.py b/tests/test_setups/test_setup_functions.py index 519009aee..9e357f6aa 100644 --- a/tests/test_setups/test_setup_functions.py +++ b/tests/test_setups/test_setup_functions.py @@ -171,11 +171,7 @@ def test_list_setups_output_format(self): assert isinstance(setups[next(iter(setups.keys()))], Dict) assert len(setups) == 10 - @pytest.mark.production() def test_setuplist_offset(self): - # TODO: remove after pull on live for better testing - # openml.config.server = self.production_server - size = 10 setups = openml.setups.list_setups(offset=0, size=size) assert len(setups) == size From b84dc2ed03ed6f50b59d1a33292fbd6fa88e5c48 Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Wed, 10 Jan 2024 00:12:22 +0100 Subject: [PATCH 03/28] fix test bug -1/N --- openml/extensions/sklearn/extension.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/openml/extensions/sklearn/extension.py b/openml/extensions/sklearn/extension.py index 00bfc7048..3427ca7c9 100644 --- a/openml/extensions/sklearn/extension.py +++ b/openml/extensions/sklearn/extension.py @@ -184,7 +184,7 @@ def remove_all_in_parentheses(string: str) -> str: if closing_parenthesis_expected == 0: break - _end: int = estimator_start + len(long_name[estimator_start:]) + _end: int = estimator_start + len(long_name[estimator_start:]) - 1 model_select_pipeline = long_name[estimator_start:_end] trimmed_pipeline = cls.trim_flow_name(model_select_pipeline, _outer=False) From 9bab91c64ad8c5a848137044c0a526770ec60f3e Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Wed, 10 Jan 2024 16:34:41 +0100 Subject: [PATCH 04/28] add retry raise again after refactor --- openml/_api_calls.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/openml/_api_calls.py b/openml/_api_calls.py index b66e7849d..bc41ec1e4 100644 --- a/openml/_api_calls.py +++ b/openml/_api_calls.py @@ -341,6 +341,9 @@ def _send_request( # noqa: C901 response: requests.Response | None = None delay_method = _human_delay if config.retry_policy == "human" else _robot_delay + # Error to raise in case of retrying too often. Will be set to the last observed exception. + retry_raise_e: Exception | None = None + with requests.Session() as session: # Start at one to have a non-zero multiplier for the sleep for retry_counter in range(1, n_retries + 1): @@ -384,10 +387,7 @@ def _send_request( # noqa: C901 # which means trying again might resolve the issue. if e.code != DATABASE_CONNECTION_ERRCODE: raise e - - delay = delay_method(retry_counter) - time.sleep(delay) - + retry_raise_e = e except xml.parsers.expat.ExpatError as e: if request_method != "get" or retry_counter >= n_retries: if response is not None: @@ -399,18 +399,21 @@ def _send_request( # noqa: C901 f"Unexpected server error when calling {url}. 
Please contact the " f"developers!\n{extra}" ) from e - - delay = delay_method(retry_counter) - time.sleep(delay) - + retry_raise_e = e except ( requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError, requests.exceptions.SSLError, OpenMLHashException, - ): - delay = delay_method(retry_counter) - time.sleep(delay) + ) as e: + retry_raise_e = e + + # We can only be here if there was an exception + assert retry_raise_e is not None + if retry_counter >= n_retries: + raise retry_raise_e + delay = delay_method(retry_counter) + time.sleep(delay) assert response is not None return response From 83ba06f1b133b0996ed4406ca29bfeb6dd081487 Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Wed, 10 Jan 2024 16:35:18 +0100 Subject: [PATCH 05/28] fix str dict representation --- openml/tasks/task.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openml/tasks/task.py b/openml/tasks/task.py index d0a520042..4270476ea 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -198,11 +198,11 @@ def get_split_dimensions(self) -> tuple[int, int, int]: return self.split.repeats, self.split.folds, self.split.samples # TODO(eddiebergman): Really need some better typing on all this - def _to_dict(self) -> dict[str, dict[str, int | str | list[dict[str, Any]]]]: - """Creates a dictionary representation of self.""" + def _to_dict(self) -> dict[str, dict[str, str | list[dict[str, Any]]]]: + """Creates a dictionary representation of self in a string format (for XML parsing).""" oml_input = [ - {"@name": "source_data", "#text": self.dataset_id}, - {"@name": "estimation_procedure", "#text": self.estimation_procedure_id}, + {"@name": "source_data", "#text": str(self.dataset_id)}, + {"@name": "estimation_procedure", "#text": str(self.estimation_procedure_id)}, ] if self.evaluation_measure is not None: # oml_input.append({"@name": "evaluation_measures", "#text": self.evaluation_measure}) From 8174ea18ece8f3096deb2849f12a860b84d9b4aa Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Wed, 10 Jan 2024 17:40:18 +0100 Subject: [PATCH 06/28] test: Fix non-writable home mocks --- openml/config.py | 31 +++++++++++++++++-------------- tests/test_openml/test_config.py | 14 +++++++------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/openml/config.py b/openml/config.py index 6ce07a6ce..4744dbe86 100644 --- a/openml/config.py +++ b/openml/config.py @@ -243,14 +243,11 @@ def _setup(config: _Config | None = None) -> None: config_dir = config_file.parent # read config file, create directory for config file - if not config_dir.exists(): - try: + try: + if not config_dir.exists(): config_dir.mkdir(exist_ok=True, parents=True) - cache_exists = True - except PermissionError: - cache_exists = False - else: - cache_exists = True + except PermissionError: + pass if config is None: config = _parse_config(config_file) @@ -264,15 +261,21 @@ def _setup(config: _Config | None = None) -> None: set_retry_policy(config["retry_policy"], n_retries) _root_cache_directory = short_cache_dir.expanduser().resolve() + + try: + cache_exists = _root_cache_directory.exists() + except PermissionError: + cache_exists = False + # create the cache subdirectory - if not _root_cache_directory.exists(): - try: + try: + if not _root_cache_directory.exists(): _root_cache_directory.mkdir(exist_ok=True, parents=True) - except PermissionError: - openml_logger.warning( - "No permission to create openml cache directory at %s! This can result in " - "OpenML-Python not working properly." 
% _root_cache_directory, - ) + except PermissionError: + openml_logger.warning( + "No permission to create openml cache directory at %s! This can result in " + "OpenML-Python not working properly." % _root_cache_directory, + ) if cache_exists: _create_log_handlers() diff --git a/tests/test_openml/test_config.py b/tests/test_openml/test_config.py index 4b2d931ee..bfb88a5db 100644 --- a/tests/test_openml/test_config.py +++ b/tests/test_openml/test_config.py @@ -4,6 +4,8 @@ import os import tempfile import unittest.mock +from copy import copy +from pathlib import Path import pytest @@ -12,22 +14,20 @@ class TestConfig(openml.testing.TestBase): - @unittest.mock.patch("os.path.expanduser") @unittest.mock.patch("openml.config.openml_logger.warning") @unittest.mock.patch("openml.config._create_log_handlers") @unittest.skipIf(os.name == "nt", "https://github.com/openml/openml-python/issues/1033") - def test_non_writable_home(self, log_handler_mock, warnings_mock, expanduser_mock): + def test_non_writable_home(self, log_handler_mock, warnings_mock): with tempfile.TemporaryDirectory(dir=self.workdir) as td: - expanduser_mock.side_effect = ( - os.path.join(td, "openmldir"), - os.path.join(td, "cachedir"), - ) os.chmod(td, 0o444) - openml.config._setup() + _dd = copy(openml.config._defaults) + _dd["cachedir"] = Path(td) / "something-else" + openml.config._setup(_dd) assert warnings_mock.call_count == 2 assert log_handler_mock.call_count == 1 assert not log_handler_mock.call_args_list[0][1]["create_file_handler"] + assert openml.config._root_cache_directory == Path(td) / "something-else" @unittest.mock.patch("os.path.expanduser") def test_XDG_directories_do_not_exist(self, expanduser_mock): From d09f4f85b0b12ea6e4144134e54c0172eca3a780 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Wed, 10 Jan 2024 18:00:08 +0100 Subject: [PATCH 07/28] testing: not not a change --- openml/runs/run.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/openml/runs/run.py b/openml/runs/run.py index 901e97d3c..a53184895 100644 --- a/openml/runs/run.py +++ b/openml/runs/run.py @@ -369,10 +369,8 @@ def to_filesystem( directory = Path(directory) directory.mkdir(exist_ok=True, parents=True) - if not any(directory.iterdir()): - raise ValueError( - f"Output directory {directory.expanduser().resolve()} should be empty", - ) + if any(directory.iterdir()): + raise ValueError(f"Output directory {directory.expanduser().resolve()} should be empty") run_xml = self._to_xml() predictions_arff = arff.dumps(self._generate_arff_dict()) From cbd23394e922251274753fe8bca069e30a0bd511 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:08:48 +0100 Subject: [PATCH 08/28] testing: trigger CI From 8d769e7bf99d9387ae59f4c5bb4920f3e1e64bb9 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:17:52 +0100 Subject: [PATCH 09/28] typing: Update typing --- openml/tasks/task.py | 4 ++-- openml/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/openml/tasks/task.py b/openml/tasks/task.py index 4270476ea..fbc0985fb 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -198,7 +198,7 @@ def get_split_dimensions(self) -> tuple[int, int, int]: return self.split.repeats, self.split.folds, self.split.samples # TODO(eddiebergman): Really need some better typing on all this - def _to_dict(self) -> dict[str, dict[str, str | list[dict[str, Any]]]]: + def _to_dict(self) -> dict[str, dict[str, int | str | list[dict[str, Any]]]]: """Creates a dictionary representation 
of self in a string format (for XML parsing).""" oml_input = [ {"@name": "source_data", "#text": str(self.dataset_id)}, @@ -210,7 +210,7 @@ def _to_dict(self) -> dict[str, dict[str, str | list[dict[str, Any]]]]: return { "oml:task_inputs": { "@xmlns:oml": "http://openml.org/openml", - "oml:task_type_id": self.task_type_id.value, + "oml:task_type_id": self.task_type_id.value, # This is an int from the enum? "oml:input": oml_input, } } diff --git a/openml/utils.py b/openml/utils.py index 63b5ac23e..80d7caaae 100644 --- a/openml/utils.py +++ b/openml/utils.py @@ -312,7 +312,7 @@ def _list_all( # noqa: C901, PLR0912 raise ValueError(f"'batch_size' should be an integer but got {BATCH_SIZE_ORIG}") # max number of results to be shown - LIMIT = active_filters.pop("size", None) + LIMIT: int | float | None = active_filters.pop("size", None) # type: ignore if (LIMIT is not None) and (not isinstance(LIMIT, int)) and (not np.isinf(LIMIT)): raise ValueError(f"'limit' should be an integer or inf but got {LIMIT}") From 52ad9744846f5a25223b37a21a1418538ed4d353 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:28:27 +0100 Subject: [PATCH 10/28] ci: Update testing matrix --- .github/workflows/test.yml | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d178c15df..0912c2be9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,22 +13,27 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9"] - scikit-learn: ["0.21.2", "0.22.2", "0.23.1", "0.24"] + # TODO(eddiebergman): We should consider testing against newer version I guess... + # We probably consider just having a `"1"` version to always test against latest + scikit-learn: ["0.23.1", "0.24"] os: [ubuntu-latest] - sklearn-only: ['true'] - exclude: # no scikit-learn 0.21.2 release for Python 3.8 - - python-version: 3.8 - scikit-learn: 0.21.2 + sklearn-only: ["true"] + exclude: # no scikit-learn 0.23 release for Python 3.9 + - python-version: "3.9" + scikit-learn: "0.23.1" include: - - python-version: 3.8 + # Include a code cov version + - code-cov: true + os: ubuntu-latest + python-version: "3.8" scikit-learn: 0.23.1 - code-cov: true sklearn-only: 'false' - os: ubuntu-latest + # Include a windows test, for some reason on a later version of scikit-learn - os: windows-latest - sklearn-only: 'false' + python-version: "3.8" scikit-learn: 0.24.* - scipy: 1.10.0 + scipy: "1.10.0" # not sure why the explicit scipy version? 
+ sklearn-only: 'false' fail-fast: false max-parallel: 4 From 10eb967291a923db73365798888f7964254fb023 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:40:30 +0100 Subject: [PATCH 11/28] testing: Fixup run flow error check --- openml/runs/functions.py | 6 ++++-- tests/test_runs/test_run_functions.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/openml/runs/functions.py b/openml/runs/functions.py index 2848bd9ed..7a082e217 100644 --- a/openml/runs/functions.py +++ b/openml/runs/functions.py @@ -262,12 +262,14 @@ def run_flow_on_task( # noqa: C901, PLR0912, PLR0915, PLR0913 if upload_flow or avoid_duplicate_runs: flow_id = flow_exists(flow.name, flow.external_version) if isinstance(flow.flow_id, int) and flow_id != flow.flow_id: - if flow_id is not None: + if flow_id is not False: raise PyOpenMLError( "Local flow_id does not match server flow_id: " f"'{flow.flow_id}' vs '{flow_id}'", ) - raise PyOpenMLError("Flow does not exist on the server, but 'flow.flow_id' is not None") + raise PyOpenMLError( + "Flow does not exist on the server, but 'flow.flow_id' is not None." + ) if upload_flow and flow_id is None: flow.publish() diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py index 4ce5a07f1..618e4d46d 100644 --- a/tests/test_runs/test_run_functions.py +++ b/tests/test_runs/test_run_functions.py @@ -1224,7 +1224,7 @@ def test_run_with_illegal_flow_id(self): flow, _ = self._add_sentinel_to_flow_name(flow, None) flow.flow_id = -1 expected_message_regex = ( - "Flow does not exist on the server, " "but 'flow.flow_id' is not None." + r"Flow does not exist on the server, but 'flow.flow_id' is not None." ) with pytest.raises(openml.exceptions.PyOpenMLError, match=expected_message_regex): openml.runs.run_flow_on_task( @@ -1258,7 +1258,7 @@ def test_run_with_illegal_flow_id_after_load(self): loaded_run = openml.runs.OpenMLRun.from_filesystem(cache_path) expected_message_regex = ( - "Flow does not exist on the server, " "but 'flow.flow_id' is not None." + r"Flow does not exist on the server, but 'flow.flow_id' is not None." 
) with pytest.raises(openml.exceptions.PyOpenMLError, match=expected_message_regex): loaded_run.publish() From 648b557d898794a4226b15d94d3937e169c036dc Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:45:09 +0100 Subject: [PATCH 12/28] ci: Manual dispatch, disable double testing --- .github/workflows/test.yml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0912c2be9..dce63adf5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,6 +1,19 @@ name: Tests -on: [push, pull_request] +on: + workflow_dispatch: + + push: + branches: + - main + - develop + tags: + - "v*.*.*" + + pull_request: + branches: + - main + - develop concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} From 5fc565b6a5dad37aef8cd109749c09a36217f659 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:49:12 +0100 Subject: [PATCH 13/28] ci: Prevent further ci duplication --- .github/workflows/dist.yaml | 15 ++++++++++++++- .github/workflows/pre-commit.yaml | 15 ++++++++++++++- .github/workflows/release_docker.yaml | 1 + 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/.github/workflows/dist.yaml b/.github/workflows/dist.yaml index 602b7edcd..d0113d1ff 100644 --- a/.github/workflows/dist.yaml +++ b/.github/workflows/dist.yaml @@ -1,6 +1,19 @@ name: dist-check -on: [push, pull_request] +on: + workflow_dispatch: + + push: + branches: + - main + - develop + tags: + - "v*.*.*" + + pull_request: + branches: + - main + - develop jobs: dist: diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 32cfc6376..c44b3e62f 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -1,6 +1,19 @@ name: pre-commit -on: [push] +on: + workflow_dispatch: + + push: + branches: + - main + - develop + tags: + - "v*.*.*" + + pull_request: + branches: + - main + - develop jobs: run-all-files: diff --git a/.github/workflows/release_docker.yaml b/.github/workflows/release_docker.yaml index 8de78fbcd..adb27e58e 100644 --- a/.github/workflows/release_docker.yaml +++ b/.github/workflows/release_docker.yaml @@ -1,6 +1,7 @@ name: release-docker on: + workflow_dispatch: push: branches: - 'develop' From 00a5280b8da01ee78bafe530fb08fda974607245 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 11:54:06 +0100 Subject: [PATCH 14/28] ci: Add concurrency checks to all --- .github/workflows/dist.yaml | 4 ++++ .github/workflows/docs.yaml | 19 ++++++++++++++++++- .github/workflows/pre-commit.yaml | 4 ++++ .github/workflows/release_docker.yaml | 4 ++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dist.yaml b/.github/workflows/dist.yaml index d0113d1ff..b81651cea 100644 --- a/.github/workflows/dist.yaml +++ b/.github/workflows/dist.yaml @@ -15,6 +15,10 @@ on: - main - develop +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: dist: runs-on: ubuntu-latest diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 28f51378d..e50d67710 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -1,5 +1,22 @@ name: Docs -on: [pull_request, push] +on: + workflow_dispatch: + + push: + branches: + - main + - develop + tags: + - "v*.*.*" + + pull_request: + branches: + - main + - develop + +concurrency: + group: ${{ github.workflow 
}}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: build-and-deploy: diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index c44b3e62f..9d1ab7fa8 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -15,6 +15,10 @@ on: - main - develop +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: run-all-files: runs-on: ubuntu-latest diff --git a/.github/workflows/release_docker.yaml b/.github/workflows/release_docker.yaml index adb27e58e..c8f8c59f8 100644 --- a/.github/workflows/release_docker.yaml +++ b/.github/workflows/release_docker.yaml @@ -12,6 +12,10 @@ on: branches: - 'develop' +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: docker: From 6cb175d5bed0f709888cc39fb8c22ba009aaf626 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 12:04:19 +0100 Subject: [PATCH 15/28] ci: Remove the max-parallel on test CI There are a lot fewer jobs now, and they cancel runs from previous pushes in the same PR, so it shouldn't be a problem anymore --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dce63adf5..5cc31ed15 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -48,7 +48,6 @@ jobs: scipy: "1.10.0" # not sure why the explicit scipy version? sklearn-only: 'false' fail-fast: false - max-parallel: 4 steps: - uses: actions/checkout@v4 From 1fc78193416eb4f773c9a9860c2fa8939869239a Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 11 Jan 2024 13:20:56 +0100 Subject: [PATCH 16/28] testing: Fix windows path generation --- openml/testing.py | 8 ++++---- tests/test_runs/test_run_functions.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openml/testing.py b/openml/testing.py index 60f4eb4a6..4af361507 100644 --- a/openml/testing.py +++ b/openml/testing.py @@ -76,20 +76,20 @@ def setUp(self, n_levels: int = 1) -> None: # This cache directory is checked in to git to simulate a populated # cache self.maxDiff = None - self.static_cache_dir = None abspath_this_file = Path(inspect.getfile(self.__class__)).absolute() static_cache_dir = abspath_this_file.parent for _ in range(n_levels): static_cache_dir = static_cache_dir.parent.absolute() + content = os.listdir(static_cache_dir) if "files" in content: - self.static_cache_dir = static_cache_dir / "files" - - if self.static_cache_dir is None: + static_cache_dir = static_cache_dir / "files" + else: raise ValueError( f"Cannot find test cache dir, expected it to be {static_cache_dir}!", ) + self.static_cache_dir = static_cache_dir self.cwd = Path.cwd() workdir = Path(__file__).parent.absolute() tmp_dir_name = self.id() diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py index 618e4d46d..edd7e0198 100644 --- a/tests/test_runs/test_run_functions.py +++ b/tests/test_runs/test_run_functions.py @@ -1386,7 +1386,7 @@ def test__run_task_get_arffcontent(self): self.assertAlmostEqual(sum(arff_line[6:]), 1.0) def test__create_trace_from_arff(self): - with open(self.static_cache_dir + "/misc/trace.arff") as arff_file: + with open(self.static_cache_dir / "misc" / "trace.arff") as arff_file: trace_arff = arff.load(arff_file) OpenMLRunTrace.trace_from_arff(trace_arff) From 162437501c3ba1c6843672612fecf2faf09b536c Mon Sep 17 00:00:00 
2001 From: Lennart Purucker Date: Thu, 11 Jan 2024 13:27:44 +0100 Subject: [PATCH 17/28] add pytest for server state --- tests/test_utils/test_utils.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index 299d4007b..b53b5a64d 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -2,9 +2,8 @@ import os import unittest.mock - import pytest - +import pytest +import shutil import openml from openml.testing import _check_dataset @@ -25,6 +24,20 @@ def with_test_server(): openml.config.stop_using_configuration_for_example() +@pytest.fixture(autouse=True) +def with_test_cache(test_files_directory, request): + if not test_files_directory.exists(): + raise ValueError( + f"Cannot find test cache dir, expected it to be {test_files_directory!s}!", + ) + _root_cache_directory = openml.config._root_cache_directory + tmp_cache = _root_cache_directory / request.node.name + openml.config.set_root_cache_directory(tmp_cache) + yield + openml.config.set_root_cache_directory(_root_cache_directory) + shutil.rmtree(tmp_cache) + + @pytest.fixture() def min_number_tasks_on_test_server() -> int: """After a reset at least 1068 tasks are on the test server""" @@ -176,3 +189,16 @@ def test__create_cache_directory(config_mock, tmp_path): match="Cannot create cache directory", ): openml.utils._create_cache_directory("ghi") + + +@pytest.mark.server() +def test_correct_test_server_download_state(): + """This test verifies that the test server downloads the data from the correct source. + + If this test fails, it is highly likely that the test server is not configured correctly. + Usually, this means that the test server is serving data for the task with the same ID from the production server. + That is, it serves parquet files wrongly associated with the test server's task. 
+ """ + task = openml.tasks.get_task(119) + dataset = task.get_dataset() + assert len(dataset.features) == dataset.get_data(dataset_format="dataframe")[0].shape[1] \ No newline at end of file From 9ca10b6e0fde51cd68bbafa233ce1e8706891384 Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Fri, 12 Jan 2024 09:27:49 +0100 Subject: [PATCH 18/28] add assert cache state --- tests/conftest.py | 59 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8f353b73c..e72e98e96 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,8 +25,7 @@ import logging import os -import pathlib - +from pathlib import Path import pytest import openml @@ -53,20 +52,20 @@ def worker_id() -> str: return "master" -def read_file_list() -> list[pathlib.Path]: +def read_file_list() -> list[Path]: """Returns a list of paths to all files that currently exist in 'openml/tests/files/' - :return: List[pathlib.Path] + :return: List[Path] """ - test_files_dir = pathlib.Path(__file__).parent / "files" + test_files_dir = Path(__file__).parent / "files" return [f for f in test_files_dir.rglob("*") if f.is_file()] -def compare_delete_files(old_list: list[pathlib.Path], new_list: list[pathlib.Path]) -> None: +def compare_delete_files(old_list: list[Path], new_list: list[Path]) -> None: """Deletes files that are there in the new_list but not in the old_list - :param old_list: List[pathlib.Path] - :param new_list: List[pathlib.Path] + :param old_list: List[Path] + :param new_list: List[Path] :return: None """ file_list = list(set(new_list) - set(old_list)) @@ -183,16 +182,56 @@ def pytest_addoption(parser): ) +def _expected_static_cache_state(root_dir: Path) -> list[Path]: + _c_root_dir = root_dir/"org"/"openml"/"test" + res_paths = [root_dir, _c_root_dir] + + for _d in ["datasets", "tasks", "runs", "setups"]: + res_paths.append(_c_root_dir / _d) + + for _id in ["-1","2"]: + tmp_p = _c_root_dir / "datasets" / _id + res_paths.extend([ + tmp_p / "dataset.arff", + tmp_p / "features.xml", + tmp_p / "qualities.xml", + tmp_p / "description.xml", + ]) + res_paths.append(_c_root_dir / "datasets" / "30" / "dataset_30.pq") + res_paths.append(_c_root_dir / "runs" / "1" / "description.xml") + res_paths.append(_c_root_dir / "setups" / "1" / "description.xml") + + for _id in ["1","3", "1882"]: + tmp_p = _c_root_dir / "tasks" / _id + res_paths.extend([ + tmp_p / "datasplits.arff", + tmp_p / "task.xml", + ]) + + return res_paths + +def assert_static_test_cache_correct(root_dir: Path) -> None: + for p in _expected_static_cache_state(root_dir): + assert p.exists(), f"Expected path {p} does not exist" + + @pytest.fixture(scope="class") def long_version(request): request.cls.long_version = request.config.getoption("--long") @pytest.fixture() -def test_files_directory() -> pathlib.Path: - return pathlib.Path(__file__).parent / "files" +def test_files_directory() -> Path: + return Path(__file__).parent / "files" @pytest.fixture() def test_api_key() -> str: return "c0c42819af31e706efe1f4b88c23c6c1" + + +@pytest.fixture(autouse=True) +def verify_cache_state(test_files_directory) -> None: + assert_static_test_cache_correct(test_files_directory) + yield + assert_static_test_cache_correct(test_files_directory) From 7c5106d398aef9b5d2c1f26f638a0358baa4e2e2 Mon Sep 17 00:00:00 2001 From: Lennart Purucker Date: Fri, 12 Jan 2024 09:28:57 +0100 Subject: [PATCH 19/28] some formatting --- tests/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 
From 7c5106d398aef9b5d2c1f26f638a0358baa4e2e2 Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 09:28:57 +0100
Subject: [PATCH 19/28] some formatting

---
 tests/conftest.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index e72e98e96..62fe3c7e8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -183,7 +183,7 @@ def pytest_addoption(parser):
 
 
 def _expected_static_cache_state(root_dir: Path) -> list[Path]:
-    _c_root_dir = root_dir/"org"/"openml"/"test"
+    _c_root_dir = root_dir / "org" / "openml" / "test"
     res_paths = [root_dir, _c_root_dir]
 
     for _d in ["datasets", "tasks", "runs", "setups"]:
@@ -197,11 +197,12 @@ def _expected_static_cache_state(root_dir: Path) -> list[Path]:
             tmp_p / "qualities.xml",
             tmp_p / "description.xml",
         ])
+
     res_paths.append(_c_root_dir / "datasets" / "30" / "dataset_30.pq")
     res_paths.append(_c_root_dir / "runs" / "1" / "description.xml")
     res_paths.append(_c_root_dir / "setups" / "1" / "description.xml")
 
-    for _id in ["1","3", "1882"]:
+    for _id in ["1", "3", "1882"]:
         tmp_p = _c_root_dir / "tasks" / _id
         res_paths.extend([
             tmp_p / "datasplits.arff",
@@ -210,6 +211,7 @@ def _expected_static_cache_state(root_dir: Path) -> list[Path]:
 
     return res_paths
 
+
 def assert_static_test_cache_correct(root_dir: Path) -> None:
     for p in _expected_static_cache_state(root_dir):
         assert p.exists(), f"Expected path {p} does not exist"

From 1ae6573d4e98feccc3cde94d0c2d2dccef3569ff Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 09:35:38 +0100
Subject: [PATCH 20/28] fix with cache fixture

---
 tests/test_utils/test_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py
index b53b5a64d..cae947917 100644
--- a/tests/test_utils/test_utils.py
+++ b/tests/test_utils/test_utils.py
@@ -31,11 +31,12 @@ def with_test_cache(test_files_directory, request):
             f"Cannot find test cache dir, expected it to be {test_files_directory!s}!",
         )
     _root_cache_directory = openml.config._root_cache_directory
-    tmp_cache = _root_cache_directory / request.node.name
+    tmp_cache = test_files_directory / request.node.name
     openml.config.set_root_cache_directory(tmp_cache)
     yield
     openml.config.set_root_cache_directory(_root_cache_directory)
-    shutil.rmtree(tmp_cache)
+    if tmp_cache.exists():
+        shutil.rmtree(tmp_cache)

From b22fd4d8667d2391eaca6ffa74813d6a1e9f1863 Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 09:51:25 +0100
Subject: [PATCH 21/28] finally remove the finally

---
 openml/tasks/functions.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/openml/tasks/functions.py b/openml/tasks/functions.py
index c12da95a7..c763714bf 100644
--- a/openml/tasks/functions.py
+++ b/openml/tasks/functions.py
@@ -66,9 +66,8 @@ def _get_cached_task(tid: int) -> OpenMLTask:
         with task_xml_path.open(encoding="utf8") as fh:
             return _create_task_from_xml(fh.read())
     except OSError as e:
-        raise OpenMLCacheException(f"Task file for tid {tid} not cached") from e
-    finally:
         openml.utils._remove_cache_dir_for_id(TASKS_CACHE_DIR_NAME, tid_cache_dir)
+        raise OpenMLCacheException(f"Task file for tid {tid} not cached") from e
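PATCH 21's one-line move matters because a `finally` clause runs even when
the `try` block returns successfully, so the old code deleted the cache
directory after every successful read, not just on failure. A standalone
sketch of the two behaviours (file and function names are illustrative, not
the library's):

    import shutil
    from pathlib import Path

    def read_task_old(cache_dir: Path) -> str:
        try:
            return (cache_dir / "task.xml").read_text(encoding="utf8")
        except OSError as e:
            raise KeyError("task not cached") from e
        finally:
            # Bug: runs on success too, wiping a perfectly valid cache entry.
            shutil.rmtree(cache_dir, ignore_errors=True)

    def read_task_fixed(cache_dir: Path) -> str:
        try:
            return (cache_dir / "task.xml").read_text(encoding="utf8")
        except OSError as e:
            # Clean up only on the failure path; successful reads keep the cache.
            shutil.rmtree(cache_dir, ignore_errors=True)
            raise KeyError("task not cached") from e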
From 43744be8a1d25835619c323a8b10f1bcef81f860 Mon Sep 17 00:00:00 2001
From: eddiebergman
Date: Fri, 12 Jan 2024 10:06:10 +0100
Subject: [PATCH 22/28] doc: Fix link

---
 doc/contributing.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/contributing.rst b/doc/contributing.rst
index e8d537338..89c26746c 100644
--- a/doc/contributing.rst
+++ b/doc/contributing.rst
@@ -19,7 +19,7 @@ In particular, a few ways to contribute to openml-python are:
   For more information, see the :ref:`extensions` below.
 
 * Bug reports. If something doesn't work for you or is cumbersome, please open a new issue to let
-  us know about the problem. See `this section `_.
+  us know about the problem. See `this section `_.
 
 * `Cite OpenML `_ if you use it in a scientific publication.

From 4a07696efda181581879229e48fe98efd8bd8e1d Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 10:22:30 +0100
Subject: [PATCH 23/28] update test matrix

---
 .github/workflows/test.yml | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5cc31ed15..ab60f59c6 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -25,7 +25,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        python-version: ["3.8", "3.9"]
+        python-version: ["3.8"]
         # TODO(eddiebergman): We should consider testing against newer version I guess...
        # We probably consider just having a `"1"` version to always test against latest
         scikit-learn: ["0.23.1", "0.24"]
@@ -35,6 +35,11 @@ jobs:
           - python-version: "3.9"
             scikit-learn: "0.23.1"
         include:
+          - os: ubuntu-latest
+            python-version: "3.9"
+            scikit-learn: "0.24"
+            scipy: "1.10.0"
+            sklearn-only: "true"
           # Include a code cov version
           - code-cov: true
             os: ubuntu-latest

From 305dee007f385c53d583aefae572a202a1311ddf Mon Sep 17 00:00:00 2001
From: eddiebergman
Date: Fri, 12 Jan 2024 10:43:01 +0100
Subject: [PATCH 24/28] doc: Update to just point to contributing

---
 doc/contributing.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/contributing.rst b/doc/contributing.rst
index 89c26746c..34d1edb14 100644
--- a/doc/contributing.rst
+++ b/doc/contributing.rst
@@ -19,7 +19,7 @@ In particular, a few ways to contribute to openml-python are:
   For more information, see the :ref:`extensions` below.
 
 * Bug reports. If something doesn't work for you or is cumbersome, please open a new issue to let
-  us know about the problem. See `this section `_.
+  us know about the problem. See `this section `_.
 
 * `Cite OpenML `_ if you use it in a scientific publication.

From 1b70078fd81e1d897c8c15a70a2e12b6cb3bc740 Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 11:05:50 +0100
Subject: [PATCH 25/28] add linkcheck ignore for test server

---
 doc/conf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/conf.py b/doc/conf.py
index a10187486..61ba4a46c 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -119,7 +119,7 @@
 #
 # currently disabled because without intersphinx we cannot link to numpy.ndarray
 # nitpicky = True
-
+linkcheck_ignore = [r"https://test.openml.org/t/.*"]  # FIXME: to avoid test server bugs breaking docs building
 # -- Options for HTML output ----------------------------------------------
 
 # The theme to use for HTML and HTML Help pages.  See the documentation for
From 9671c39fd079f9d9c6f1681a81bf037277c14cc9 Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 11:25:22 +0100
Subject: [PATCH 26/28] add special case for class labels that are dtype string

---
 openml/datasets/dataset.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py
index f81ddd23a..69535155f 100644
--- a/openml/datasets/dataset.py
+++ b/openml/datasets/dataset.py
@@ -908,8 +908,18 @@ def retrieve_class_labels(self, target_name: str = "class") -> None | list[str]
         list
         """
         for feature in self.features.values():
-            if (feature.name == target_name) and (feature.data_type == "nominal"):
-                return feature.nominal_values
+            if feature.name == target_name:
+                if feature.data_type == "nominal":
+                    return feature.nominal_values
+
+                if feature.data_type == "string":
+                    # Rel.: #1311
+                    # The target is invalid for a classification task if the feature type is string
+                    # and not nominal. For such misconfigured tasks, we silently fix it here as
+                    # we can safely interpret strings as nominal.
+                    df, *_ = self.get_data()
+                    return list(df.loc[feature.name].unique())
+
         return None
 
     def get_features_by_type(  # noqa: C901
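To make PATCH 26's special case concrete: a target column that the server
typed as string rather than nominal now still yields class labels, taken
from the column's unique values. A usage sketch, mirroring the test added in
the next patch (dataset 2 is the test server's copy of anneal; the feature
index is illustrative):

    import openml

    dataset = openml.datasets.get_dataset(2, download_data=False)
    # Simulate a misconfigured task whose target was typed as string:
    dataset.features[31].data_type = "string"
    labels = dataset.retrieve_class_labels(target_name=dataset.features[31].name)
    print(labels)  # the unique values of the column, interpreted as nominal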
From 1097b7b82c6a8f569e339f0a23e9ca7b95f7dae6 Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 11:36:32 +0100
Subject: [PATCH 27/28] fix bug and add test

---
 openml/datasets/dataset.py                    | 2 +-
 tests/test_datasets/test_dataset_functions.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py
index 69535155f..107827238 100644
--- a/openml/datasets/dataset.py
+++ b/openml/datasets/dataset.py
@@ -918,7 +918,7 @@ def retrieve_class_labels(self, target_name: str = "class") -> None | list[str]
                     # and not nominal. For such misconfigured tasks, we silently fix it here as
                     # we can safely interpret strings as nominal.
                     df, *_ = self.get_data()
-                    return list(df.loc[feature.name].unique())
+                    return list(df[feature.name].unique())
 
         return None
 
diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py
index 9fbb9259a..6a6672a5b 100644
--- a/tests/test_datasets/test_dataset_functions.py
+++ b/tests/test_datasets/test_dataset_functions.py
@@ -631,6 +631,12 @@ def test__retrieve_class_labels(self):
         )
         assert labels == ["C", "H", "G"]
 
+        # Test workaround for string-typed class labels
+        custom_ds = openml.datasets.get_dataset(2, download_data=False)
+        custom_ds.features[31].data_type = "string"
+        labels = custom_ds.retrieve_class_labels(target_name=custom_ds.features[31].name)
+        assert labels == ["COIL", "SHEET"]
+
     def test_upload_dataset_with_url(self):
         dataset = OpenMLDataset(
             "%s-UploadTestWithURL" % self._get_sentinel(),

From f3c520de60714e7699cdcf689155ba4f8ee1aba6 Mon Sep 17 00:00:00 2001
From: Lennart Purucker
Date: Fri, 12 Jan 2024 11:36:49 +0100
Subject: [PATCH 28/28] formatting

---
 tests/test_datasets/test_dataset_functions.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py
index 6a6672a5b..f3d269dc1 100644
--- a/tests/test_datasets/test_dataset_functions.py
+++ b/tests/test_datasets/test_dataset_functions.py
@@ -626,6 +626,7 @@ def test__retrieve_class_labels(self):
         openml.config.set_root_cache_directory(self.static_cache_dir)
         labels = openml.datasets.get_dataset(2, download_data=False).retrieve_class_labels()
         assert labels == ["1", "2", "3", "4", "5", "U"]
+
         labels = openml.datasets.get_dataset(2, download_data=False).retrieve_class_labels(
             target_name="product-type",
         )
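A closing note on the one-character fix in PATCH 27: `df.loc[name]` selects
a row by index label, while `df[name]` selects a column; class labels live
in a column, so the `.loc` form raises `KeyError` on any ordinary dataframe.
A standalone illustration with made-up data:

    import pandas as pd

    df = pd.DataFrame({"class": ["COIL", "SHEET", "COIL"]})

    print(list(df["class"].unique()))  # ['COIL', 'SHEET'] -- column lookup
    # df.loc["class"] would look up a *row* labelled "class" and raise a
    # KeyError here, which is exactly the bug PATCH 27 fixes.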