diff --git a/openml/extensions/extension_interface.py b/openml/extensions/extension_interface.py index 6346cb0bf..d963edb1b 100644 --- a/openml/extensions/extension_interface.py +++ b/openml/extensions/extension_interface.py @@ -58,7 +58,9 @@ def can_handle_model(cls, model: Any) -> bool: # Abstract methods for flow serialization and de-serialization @abstractmethod - def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = False) -> Any: + def flow_to_model(self, flow: 'OpenMLFlow', + initialize_with_defaults: bool = False, + strict_version: bool = True) -> Any: """Instantiate a model from the flow representation. Parameters @@ -69,6 +71,9 @@ def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = Fal If this flag is set, the hyperparameter values of flows will be ignored and a flow with its defaults is returned. + strict_version : bool, default=True + Whether to fail if version requirements are not fulfilled. + Returns ------- Any diff --git a/openml/extensions/sklearn/extension.py b/openml/extensions/sklearn/extension.py index d44b61ae7..8ed13bb29 100644 --- a/openml/extensions/sklearn/extension.py +++ b/openml/extensions/sklearn/extension.py @@ -206,7 +206,9 @@ def remove_all_in_parentheses(string: str) -> str: ################################################################################################ # Methods for flow serialization and de-serialization - def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = False) -> Any: + def flow_to_model(self, flow: 'OpenMLFlow', + initialize_with_defaults: bool = False, + strict_version: bool = True) -> Any: """Initializes a sklearn model based on a flow. Parameters @@ -219,11 +221,16 @@ def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = Fal If this flag is set, the hyperparameter values of flows will be ignored and a flow with its defaults is returned. + strict_version : bool, default=True + Whether to fail if version requirements are not fulfilled. + Returns ------- mixed """ - return self._deserialize_sklearn(flow, initialize_with_defaults=initialize_with_defaults) + return self._deserialize_sklearn( + flow, initialize_with_defaults=initialize_with_defaults, + strict_version=strict_version) def _deserialize_sklearn( self, @@ -231,6 +238,7 @@ def _deserialize_sklearn( components: Optional[Dict] = None, initialize_with_defaults: bool = False, recursion_depth: int = 0, + strict_version: bool = True, ) -> Any: """Recursive function to deserialize a scikit-learn flow. @@ -254,6 +262,9 @@ def _deserialize_sklearn( The depth at which this flow is called, mostly for debugging purposes + strict_version : bool, default=True + Whether to fail if version requirements are not fulfilled. + Returns ------- mixed @@ -290,13 +301,15 @@ def _deserialize_sklearn( rval = self._deserialize_function(value) elif serialized_type == 'component_reference': assert components is not None # Necessary for mypy - value = self._deserialize_sklearn(value, recursion_depth=depth_pp) + value = self._deserialize_sklearn(value, recursion_depth=depth_pp, + strict_version=strict_version) step_name = value['step_name'] key = value['key'] component = self._deserialize_sklearn( components[key], initialize_with_defaults=initialize_with_defaults, - recursion_depth=depth_pp + recursion_depth=depth_pp, + strict_version=strict_version, ) # The component is now added to where it should be used # later. It should not be passed to the constructor of the @@ -310,7 +323,8 @@ def _deserialize_sklearn( rval = (step_name, component, value['argument_1']) elif serialized_type == 'cv_object': rval = self._deserialize_cross_validator( - value, recursion_depth=recursion_depth + value, recursion_depth=recursion_depth, + strict_version=strict_version ) else: raise ValueError('Cannot flow_to_sklearn %s' % serialized_type) @@ -323,12 +337,14 @@ def _deserialize_sklearn( components=components, initialize_with_defaults=initialize_with_defaults, recursion_depth=depth_pp, + strict_version=strict_version ), self._deserialize_sklearn( o=value, components=components, initialize_with_defaults=initialize_with_defaults, recursion_depth=depth_pp, + strict_version=strict_version ) ) for key, value in sorted(o.items()) @@ -340,6 +356,7 @@ def _deserialize_sklearn( components=components, initialize_with_defaults=initialize_with_defaults, recursion_depth=depth_pp, + strict_version=strict_version ) for element in o ] @@ -354,6 +371,7 @@ def _deserialize_sklearn( flow=o, keep_defaults=initialize_with_defaults, recursion_depth=recursion_depth, + strict_version=strict_version ) else: raise TypeError(o) @@ -779,10 +797,12 @@ def _deserialize_model( flow: OpenMLFlow, keep_defaults: bool, recursion_depth: int, + strict_version: bool = True ) -> Any: logging.info('-%s deserialize %s' % ('-' * recursion_depth, flow.name)) model_name = flow.class_name - self._check_dependencies(flow.dependencies) + self._check_dependencies(flow.dependencies, + strict_version=strict_version) parameters = flow.parameters components = flow.components @@ -804,6 +824,7 @@ def _deserialize_model( components=components_, initialize_with_defaults=keep_defaults, recursion_depth=recursion_depth + 1, + strict_version=strict_version, ) parameter_dict[name] = rval @@ -818,6 +839,7 @@ def _deserialize_model( rval = self._deserialize_sklearn( value, recursion_depth=recursion_depth + 1, + strict_version=strict_version ) parameter_dict[name] = rval @@ -843,7 +865,8 @@ def _deserialize_model( del parameter_dict[param] return model_class(**parameter_dict) - def _check_dependencies(self, dependencies: str) -> None: + def _check_dependencies(self, dependencies: str, + strict_version: bool = True) -> None: if not dependencies: return @@ -871,9 +894,13 @@ def _check_dependencies(self, dependencies: str) -> None: else: raise NotImplementedError( 'operation \'%s\' is not supported' % operation) + message = ('Trying to deserialize a model with dependency ' + '%s not satisfied.' % dependency_string) if not check: - raise ValueError('Trying to deserialize a model with dependency ' - '%s not satisfied.' % dependency_string) + if strict_version: + raise ValueError(message) + else: + warnings.warn(message) def _serialize_type(self, o: Any) -> 'OrderedDict[str, str]': mapping = {float: 'float', @@ -991,6 +1018,7 @@ def _deserialize_cross_validator( self, value: 'OrderedDict[str, Any]', recursion_depth: int, + strict_version: bool = True ) -> Any: model_name = value['name'] parameters = value['parameters'] @@ -1002,6 +1030,7 @@ def _deserialize_cross_validator( parameters[parameter] = self._deserialize_sklearn( parameters[parameter], recursion_depth=recursion_depth + 1, + strict_version=strict_version ) return model_class(**parameters) diff --git a/openml/flows/functions.py b/openml/flows/functions.py index d12bcfe91..2b327f6be 100644 --- a/openml/flows/functions.py +++ b/openml/flows/functions.py @@ -71,7 +71,8 @@ def _get_cached_flow(fid: int) -> OpenMLFlow: @openml.utils.thread_safe_if_oslo_installed -def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow: +def get_flow(flow_id: int, reinstantiate: bool = False, + strict_version: bool = True) -> OpenMLFlow: """Download the OpenML flow for a given flow ID. Parameters @@ -82,6 +83,9 @@ def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow: reinstantiate: bool Whether to reinstantiate the flow to a model instance. + strict_version : bool, default=True + Whether to fail if version requirements are not fulfilled. + Returns ------- flow : OpenMLFlow @@ -91,7 +95,13 @@ def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow: flow = _get_flow_description(flow_id) if reinstantiate: - flow.model = flow.extension.flow_to_model(flow) + flow.model = flow.extension.flow_to_model( + flow, strict_version=strict_version) + if not strict_version: + # check if we need to return a new flow b/c of version mismatch + new_flow = flow.extension.model_to_flow(flow.model) + if new_flow.dependencies != flow.dependencies: + return new_flow return flow diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py index 95b4fa3f0..941cb6a90 100644 --- a/tests/test_flows/test_flow_functions.py +++ b/tests/test_flows/test_flow_functions.py @@ -290,9 +290,17 @@ def test_get_flow_reinstantiate_model_wrong_version(self): openml.config.server = self.production_server _, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3] flow = 8175 - expected = 'Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied.' + expected = ('Trying to deserialize a model with dependency' + ' sklearn==0.19.1 not satisfied.') self.assertRaisesRegex(ValueError, expected, openml.flows.get_flow, flow_id=flow, reinstantiate=True) + if LooseVersion(sklearn.__version__) > "0.19.1": + # 0.18 actually can't deserialize this because of incompatibility + flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True, + strict_version=False) + # ensure that a new flow was created + assert flow.flow_id is None + assert "0.19.1" not in flow.dependencies