Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion openml/extensions/extension_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def can_handle_model(cls, model: Any) -> bool:
# Abstract methods for flow serialization and de-serialization

@abstractmethod
def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = False) -> Any:
def flow_to_model(self, flow: 'OpenMLFlow',
initialize_with_defaults: bool = False,
strict_version: bool = True) -> Any:
"""Instantiate a model from the flow representation.

Parameters
Expand All @@ -69,6 +71,9 @@ def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = Fal
If this flag is set, the hyperparameter values of flows will be
ignored and a flow with its defaults is returned.

strict_version : bool, default=True
Whether to fail if version requirements are not fulfilled.

Returns
-------
Any
Expand Down
47 changes: 38 additions & 9 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def remove_all_in_parentheses(string: str) -> str:
################################################################################################
# Methods for flow serialization and de-serialization

def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = False) -> Any:
def flow_to_model(self, flow: 'OpenMLFlow',
initialize_with_defaults: bool = False,
strict_version: bool = True) -> Any:
"""Initializes a sklearn model based on a flow.

Parameters
Expand All @@ -219,18 +221,24 @@ def flow_to_model(self, flow: 'OpenMLFlow', initialize_with_defaults: bool = Fal
If this flag is set, the hyperparameter values of flows will be
ignored and a flow with its defaults is returned.

strict_version : bool, default=True
Whether to fail if version requirements are not fulfilled.

Returns
-------
mixed
"""
return self._deserialize_sklearn(flow, initialize_with_defaults=initialize_with_defaults)
return self._deserialize_sklearn(
flow, initialize_with_defaults=initialize_with_defaults,
strict_version=strict_version)

def _deserialize_sklearn(
self,
o: Any,
components: Optional[Dict] = None,
initialize_with_defaults: bool = False,
recursion_depth: int = 0,
strict_version: bool = True,
) -> Any:
"""Recursive function to deserialize a scikit-learn flow.

Expand All @@ -254,6 +262,9 @@ def _deserialize_sklearn(
The depth at which this flow is called, mostly for debugging
purposes

strict_version : bool, default=True
Whether to fail if version requirements are not fulfilled.

Returns
-------
mixed
Expand Down Expand Up @@ -290,13 +301,15 @@ def _deserialize_sklearn(
rval = self._deserialize_function(value)
elif serialized_type == 'component_reference':
assert components is not None # Necessary for mypy
value = self._deserialize_sklearn(value, recursion_depth=depth_pp)
value = self._deserialize_sklearn(value, recursion_depth=depth_pp,
strict_version=strict_version)
step_name = value['step_name']
key = value['key']
component = self._deserialize_sklearn(
components[key],
initialize_with_defaults=initialize_with_defaults,
recursion_depth=depth_pp
recursion_depth=depth_pp,
strict_version=strict_version,
)
# The component is now added to where it should be used
# later. It should not be passed to the constructor of the
Expand All @@ -310,7 +323,8 @@ def _deserialize_sklearn(
rval = (step_name, component, value['argument_1'])
elif serialized_type == 'cv_object':
rval = self._deserialize_cross_validator(
value, recursion_depth=recursion_depth
value, recursion_depth=recursion_depth,
strict_version=strict_version
)
else:
raise ValueError('Cannot flow_to_sklearn %s' % serialized_type)
Expand All @@ -323,12 +337,14 @@ def _deserialize_sklearn(
components=components,
initialize_with_defaults=initialize_with_defaults,
recursion_depth=depth_pp,
strict_version=strict_version
),
self._deserialize_sklearn(
o=value,
components=components,
initialize_with_defaults=initialize_with_defaults,
recursion_depth=depth_pp,
strict_version=strict_version
)
)
for key, value in sorted(o.items())
Expand All @@ -340,6 +356,7 @@ def _deserialize_sklearn(
components=components,
initialize_with_defaults=initialize_with_defaults,
recursion_depth=depth_pp,
strict_version=strict_version
)
for element in o
]
Expand All @@ -354,6 +371,7 @@ def _deserialize_sklearn(
flow=o,
keep_defaults=initialize_with_defaults,
recursion_depth=recursion_depth,
strict_version=strict_version
)
else:
raise TypeError(o)
Expand Down Expand Up @@ -779,10 +797,12 @@ def _deserialize_model(
flow: OpenMLFlow,
keep_defaults: bool,
recursion_depth: int,
strict_version: bool = True
) -> Any:
logging.info('-%s deserialize %s' % ('-' * recursion_depth, flow.name))
model_name = flow.class_name
self._check_dependencies(flow.dependencies)
self._check_dependencies(flow.dependencies,
strict_version=strict_version)

parameters = flow.parameters
components = flow.components
Expand All @@ -804,6 +824,7 @@ def _deserialize_model(
components=components_,
initialize_with_defaults=keep_defaults,
recursion_depth=recursion_depth + 1,
strict_version=strict_version,
)
parameter_dict[name] = rval

Expand All @@ -818,6 +839,7 @@ def _deserialize_model(
rval = self._deserialize_sklearn(
value,
recursion_depth=recursion_depth + 1,
strict_version=strict_version
)
parameter_dict[name] = rval

Expand All @@ -843,7 +865,8 @@ def _deserialize_model(
del parameter_dict[param]
return model_class(**parameter_dict)

def _check_dependencies(self, dependencies: str) -> None:
def _check_dependencies(self, dependencies: str,
strict_version: bool = True) -> None:
if not dependencies:
return

Expand Down Expand Up @@ -871,9 +894,13 @@ def _check_dependencies(self, dependencies: str) -> None:
else:
raise NotImplementedError(
'operation \'%s\' is not supported' % operation)
message = ('Trying to deserialize a model with dependency '
'%s not satisfied.' % dependency_string)
if not check:
raise ValueError('Trying to deserialize a model with dependency '
'%s not satisfied.' % dependency_string)
if strict_version:
raise ValueError(message)
else:
warnings.warn(message)

def _serialize_type(self, o: Any) -> 'OrderedDict[str, str]':
mapping = {float: 'float',
Expand Down Expand Up @@ -991,6 +1018,7 @@ def _deserialize_cross_validator(
self,
value: 'OrderedDict[str, Any]',
recursion_depth: int,
strict_version: bool = True
) -> Any:
model_name = value['name']
parameters = value['parameters']
Expand All @@ -1002,6 +1030,7 @@ def _deserialize_cross_validator(
parameters[parameter] = self._deserialize_sklearn(
parameters[parameter],
recursion_depth=recursion_depth + 1,
strict_version=strict_version
)
return model_class(**parameters)

Expand Down
14 changes: 12 additions & 2 deletions openml/flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def _get_cached_flow(fid: int) -> OpenMLFlow:


@openml.utils.thread_safe_if_oslo_installed
def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow:
def get_flow(flow_id: int, reinstantiate: bool = False,
strict_version: bool = True) -> OpenMLFlow:
"""Download the OpenML flow for a given flow ID.

Parameters
Expand All @@ -82,6 +83,9 @@ def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow:
reinstantiate : bool
Whether to reinstantiate the flow to a model instance.

strict_version : bool, default=True
Whether to fail if version requirements are not fulfilled.

Returns
-------
flow : OpenMLFlow
Expand All @@ -91,7 +95,13 @@ def get_flow(flow_id: int, reinstantiate: bool = False) -> OpenMLFlow:
flow = _get_flow_description(flow_id)

if reinstantiate:
flow.model = flow.extension.flow_to_model(flow)
flow.model = flow.extension.flow_to_model(
flow, strict_version=strict_version)
if not strict_version:
# Check whether we need to return a new flow because of a version mismatch.
new_flow = flow.extension.model_to_flow(flow.model)
if new_flow.dependencies != flow.dependencies:
return new_flow
return flow


Expand Down
10 changes: 9 additions & 1 deletion tests/test_flows/test_flow_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,17 @@ def test_get_flow_reinstantiate_model_wrong_version(self):
openml.config.server = self.production_server
_, sklearn_major, _ = LooseVersion(sklearn.__version__).version[:3]
flow = 8175
expected = 'Trying to deserialize a model with dependency sklearn==0.19.1 not satisfied.'
expected = ('Trying to deserialize a model with dependency'
' sklearn==0.19.1 not satisfied.')
self.assertRaisesRegex(ValueError,
expected,
openml.flows.get_flow,
flow_id=flow,
reinstantiate=True)
if LooseVersion(sklearn.__version__) > "0.19.1":
# scikit-learn 0.18 actually can't deserialize this flow because of an incompatibility
flow = openml.flows.get_flow(flow_id=flow, reinstantiate=True,
strict_version=False)
# ensure that a new flow was created
assert flow.flow_id is None
assert "0.19.1" not in flow.dependencies