From 61024e66e5454df057af72cc7cd06611800a4023 Mon Sep 17 00:00:00 2001 From: ykobayashi Date: Thu, 21 Nov 2019 17:48:40 +0900 Subject: [PATCH 1/3] Add support for using run_model_on_task simply --- openml/runs/functions.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/openml/runs/functions.py b/openml/runs/functions.py index 9e7321d45..4868b0abf 100644 --- a/openml/runs/functions.py +++ b/openml/runs/functions.py @@ -25,7 +25,7 @@ OpenMLRegressionTask, OpenMLSupervisedTask, OpenMLLearningCurveTask from .run import OpenMLRun from .trace import OpenMLRunTrace -from ..tasks import TaskTypeEnum +from ..tasks import TaskTypeEnum, get_task # Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles if TYPE_CHECKING: @@ -38,7 +38,7 @@ def run_model_on_task( model: Any, - task: OpenMLTask, + task: Union[int, str, OpenMLTask], avoid_duplicate_runs: bool = True, flow_tags: List[str] = None, seed: int = None, @@ -54,8 +54,9 @@ def run_model_on_task( A model which has a function fit(X,Y) and predict(X), all supervised estimators of scikit learn follow this definition of a model [1] [1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html) - task : OpenMLTask - Task to perform. This may be a model instead if the first argument is an OpenMLTask. + task : OpenMLTask or int or str + Task to perform or Task id. + This may be a model instead if the first argument is an OpenMLTask. avoid_duplicate_runs : bool, optional (default=True) If True, the run will throw an error if the setup/task combination is already present on the server. This feature requires an internet connection. @@ -84,7 +85,7 @@ def run_model_on_task( # Flexibility currently still allowed due to code-snippet in OpenML100 paper (3-2019). # When removing this please also remove the method `is_estimator` from the extension # interface as it is only used here (MF, 3-2019) - if isinstance(model, OpenMLTask): + if isinstance(model, (int, str, OpenMLTask)): warnings.warn("The old argument order (task, model) is deprecated and " "will not be supported in the future. Please use the " "order (model, task).", DeprecationWarning) @@ -98,6 +99,9 @@ def run_model_on_task( flow = extension.model_to_flow(model) + if isinstance(task, (int, str)): + task = get_task(int(task)) + run = run_flow_on_task( task=task, flow=flow, From cb16c2702dd6c3310cb9c84e3858d4a7a7d5b140 Mon Sep 17 00:00:00 2001 From: ykobayashi Date: Thu, 21 Nov 2019 18:25:29 +0900 Subject: [PATCH 2/3] Add unit test --- tests/test_runs/test_run_functions.py | 38 +++++++++++++++++++-------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py index fe8aab808..854061148 100644 --- a/tests/test_runs/test_run_functions.py +++ b/tests/test_runs/test_run_functions.py @@ -110,9 +110,9 @@ def _compare_predictions(self, predictions, predictions_prime): return True - def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed): + def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed, + create_task_obj): run = openml.runs.get_run(run_id) - task = openml.tasks.get_task(run.task_id) # TODO: assert holdout task @@ -121,12 +121,24 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed): predictions_url = openml._api_calls._file_id_to_url(file_id) response = openml._api_calls._download_text_file(predictions_url) predictions = arff.loads(response) - run_prime = openml.runs.run_model_on_task( - model=model_prime, - task=task, - avoid_duplicate_runs=False, - seed=seed, - ) + + # if create_task_obj=False, task argument in run_model_on_task is specified task_id + if create_task_obj: + task = openml.tasks.get_task(run.task_id) + run_prime = openml.runs.run_model_on_task( + model=model_prime, + task=task, + avoid_duplicate_runs=False, + seed=seed, + ) + else: + run_prime = openml.runs.run_model_on_task( + model=model_prime, + task=run.task_id, + avoid_duplicate_runs=False, + seed=seed, + ) + predictions_prime = run_prime._generate_arff_dict() self._compare_predictions(predictions, predictions_prime) @@ -425,13 +437,17 @@ def determine_grid_size(param_grid): raise e self._rerun_model_and_compare_predictions(run.run_id, model_prime, - seed) + seed, create_task_obj=True) + self._rerun_model_and_compare_predictions(run.run_id, model_prime, + seed, create_task_obj=False) else: run_downloaded = openml.runs.get_run(run.run_id) sid = run_downloaded.setup_id model_prime = openml.setups.initialize_model(sid) - self._rerun_model_and_compare_predictions(run.run_id, - model_prime, seed) + self._rerun_model_and_compare_predictions(run.run_id, model_prime, + seed, create_task_obj=True) + self._rerun_model_and_compare_predictions(run.run_id, model_prime, + seed, create_task_obj=False) # todo: check if runtime is present self._check_fold_timing_evaluations(run.fold_evaluations, 1, num_folds, From 39d6baa83bd71750583d4348d5d9fd8f2c8ffd99 Mon Sep 17 00:00:00 2001 From: ykobayashi Date: Thu, 21 Nov 2019 19:17:15 +0900 Subject: [PATCH 3/3] fix mypy error --- openml/runs/functions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/openml/runs/functions.py b/openml/runs/functions.py index 4868b0abf..ddaf3b028 100644 --- a/openml/runs/functions.py +++ b/openml/runs/functions.py @@ -99,8 +99,13 @@ def run_model_on_task( flow = extension.model_to_flow(model) - if isinstance(task, (int, str)): - task = get_task(int(task)) + def get_task_and_type_conversion(task: Union[int, str, OpenMLTask]) -> OpenMLTask: + if isinstance(task, (int, str)): + return get_task(int(task)) + else: + return task + + task = get_task_and_type_conversion(task) run = run_flow_on_task( task=task,