diff --git a/openml/extensions/sklearn/extension.py b/openml/extensions/sklearn/extension.py index ce8e4ebf9..9efd8fbb0 100644 --- a/openml/extensions/sklearn/extension.py +++ b/openml/extensions/sklearn/extension.py @@ -432,6 +432,7 @@ def _serialize_model(self, model: Any) -> OpenMLFlow: # annotate a class of sklearn.svm.SVC() with the # tag svm? ], + extension=self, language='English', # TODO fill in dependencies! dependencies=dependencies) @@ -455,9 +456,12 @@ def _get_external_version_string( model_package_name, model_package_version_number, ) openml_version = self._format_external_version('openml', openml.__version__) + sklearn_version = self._format_external_version('sklearn', sklearn.__version__) + external_versions = set() external_versions.add(external_version) external_versions.add(openml_version) + external_versions.add(sklearn_version) for visitee in sub_components.values(): for external_version in visitee.external_version.split(','): external_versions.add(external_version) diff --git a/openml/flows/flow.py b/openml/flows/flow.py index bdd4fe6a6..33102f9d4 100644 --- a/openml/flows/flow.py +++ b/openml/flows/flow.py @@ -87,7 +87,7 @@ def __init__(self, name, description, model, components, parameters, dependencies, class_name=None, custom_name=None, binary_url=None, binary_format=None, binary_md5=None, uploader=None, upload_date=None, - flow_id=None, version=None): + flow_id=None, extension=None, version=None): self.name = name self.description = description self.model = model @@ -131,8 +131,10 @@ def __init__(self, name, description, model, components, parameters, self.language = language self.dependencies = dependencies self.flow_id = flow_id - - self._extension = get_extension_by_flow(self) + if extension is None: + self._extension = get_extension_by_flow(self) + else: + self._extension = extension @property def extension(self): diff --git a/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py b/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py index 2217b332b..bcebe417f 100644 --- a/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py +++ b/tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py @@ -1219,6 +1219,14 @@ def setUp(self): ################################################################################################ # Test methods for performing runs with this extension module + def test_run_model_on_task(self): + class MyPipe(sklearn.pipeline.Pipeline): + pass + task = openml.tasks.get_task(1) + pipe = MyPipe([('imp', Imputer()), + ('dummy', sklearn.dummy.DummyClassifier())]) + openml.runs.run_model_on_task(pipe, task) + def test_seed_model(self): # randomized models that are initialized without seeds, can be seeded randomized_clfs = [