diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index 10c5c0dd1..f87b7d84c 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.20.0" +__version__ = "3.21.0" diff --git a/src/citrine/gemd_queries/criteria.py b/src/citrine/gemd_queries/criteria.py index 13f66a305..8c32dffb3 100644 --- a/src/citrine/gemd_queries/criteria.py +++ b/src/citrine/gemd_queries/criteria.py @@ -183,6 +183,7 @@ class TagsCriteria(Serializable['TagsCriteria'], Criteria): - AND_TAGS_FILTER_TYPE: All specified tags must be present - OR_TAGS_FILTER_TYPE: At least one of the specified tags must be present - NOT_TAGS_FILTER_TYPE: None of the specified tags should be present + """ tags = properties.Set(properties.String, 'tags') @@ -200,6 +201,7 @@ class ConnectivityClassCriteria(Serializable['ConnectivityClassCriteria'], Crite Whether the material is consumed. is_produced: Optional[bool] Whether the material is produced. + """ is_consumed = properties.Optional(properties.Boolean, 'is_consumed') diff --git a/src/citrine/informatics/design_spaces/__init__.py b/src/citrine/informatics/design_spaces/__init__.py index 80b9e74ee..0b4804802 100644 --- a/src/citrine/informatics/design_spaces/__init__.py +++ b/src/citrine/informatics/design_spaces/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa from .data_source_design_space import * from .design_space import * +from .design_space_settings import * from .enumerated_design_space import * from .formulation_design_space import * from .product_design_space import * diff --git a/src/citrine/informatics/design_spaces/design_space_settings.py b/src/citrine/informatics/design_spaces/design_space_settings.py new file mode 100644 index 000000000..3c28ef0bd --- /dev/null +++ b/src/citrine/informatics/design_spaces/design_space_settings.py @@ -0,0 +1,64 @@ +from typing import Optional, Union +from uuid import UUID + +from gemd.enumeration.base_enumeration import BaseEnumeration + +from citrine._rest.resource import Resource +from citrine._serialization import properties + + +__all__ = ["DefaultDesignSpaceMode", "DesignSpaceSettings"] + + +class DefaultDesignSpaceMode(BaseEnumeration): + """The type of default design space to create. + + * ATTRIBUTE results in a product design space containing dimensions required by the predictor + * HIERARCHICAL results in a hierarchical design space resembling the shape of training data + """ + + ATTRIBUTE = 'ATTRIBUTE' + HIERARCHICAL = 'HIERARCHICAL' + + +class DesignSpaceSettings(Resource["DesignSpaceSettings"]): + """The configuration used to produce a default design space.""" + + predictor_id = properties.UUID("predictor_id") + predictor_version = properties.Optional( + properties.Union([properties.Integer(), properties.String()]), + 'predictor_version' + ) + mode = properties.Optional(properties.Enumeration(DefaultDesignSpaceMode), "mode") + exclude_intermediates = properties.Optional(properties.Boolean(), "exclude_intermediates") + include_ingredient_fraction_constraints = properties.Optional( + properties.Boolean(), "include_ingredient_fraction_constraints" + ) + include_label_fraction_constraints = properties.Optional( + properties.Boolean(), "include_label_fraction_constraints" + ) + include_label_count_constraints = properties.Optional( + properties.Boolean(), "include_label_count_constraints" + ) + include_parameter_constraints = properties.Optional( + properties.Boolean(), "include_parameter_constraints" + ) + + def __init__(self, + *, + predictor_id: Union[UUID, str], + predictor_version: Optional[Union[int, str]] = None, + mode: Optional[DefaultDesignSpaceMode] = None, + exclude_intermediates: Optional[bool] = None, + include_ingredient_fraction_constraints: Optional[bool] = None, + include_label_fraction_constraints: Optional[bool] = None, + include_label_count_constraints: Optional[bool] = None, + include_parameter_constraints: Optional[bool] = None): + self.predictor_id = predictor_id + self.predictor_version = predictor_version + self.mode = mode + self.exclude_intermediates = exclude_intermediates + self.include_ingredient_fraction_constraints = include_ingredient_fraction_constraints + self.include_label_fraction_constraints = include_label_fraction_constraints + self.include_label_count_constraints = include_label_count_constraints + self.include_parameter_constraints = include_parameter_constraints diff --git a/src/citrine/informatics/design_spaces/hierarchical_design_space.py b/src/citrine/informatics/design_spaces/hierarchical_design_space.py index 3a7d7b6e5..205441820 100644 --- a/src/citrine/informatics/design_spaces/hierarchical_design_space.py +++ b/src/citrine/informatics/design_spaces/hierarchical_design_space.py @@ -8,6 +8,7 @@ from citrine.informatics.dimensions import Dimension from citrine.informatics.design_spaces import FormulationDesignSpace from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.design_space_settings import DesignSpaceSettings __all__ = [ "TemplateLink", @@ -150,6 +151,8 @@ class HierarchicalDesignSpace(EngineResource["HierarchicalDesignSpace"], DesignS """ + _settings = properties.Optional(properties.Object(DesignSpaceSettings), "metadata.settings") + root = properties.Object(MaterialNodeDefinition, "data.instance.root") subspaces = properties.List( properties.Object(MaterialNodeDefinition), "data.instance.subspaces" @@ -179,6 +182,9 @@ def __init__( def _post_dump(self, data: dict) -> dict: data = super()._post_dump(data) + if self._settings: + data["settings"] = self._settings.dump() + root_node = data["instance"]["root"] data["instance"]["root"] = self.__unwrap_node(root_node) diff --git a/src/citrine/informatics/design_spaces/product_design_space.py b/src/citrine/informatics/design_spaces/product_design_space.py index bd7ca704d..d52f6a640 100644 --- a/src/citrine/informatics/design_spaces/product_design_space.py +++ b/src/citrine/informatics/design_spaces/product_design_space.py @@ -4,6 +4,7 @@ from citrine._rest.engine_resource import EngineResource from citrine._serialization import properties from citrine.informatics.design_spaces.design_space import DesignSpace +from citrine.informatics.design_spaces.design_space_settings import DesignSpaceSettings from citrine.informatics.dimensions import Dimension __all__ = ['ProductDesignSpace'] @@ -28,6 +29,8 @@ class ProductDesignSpace(EngineResource['ProductDesignSpace'], DesignSpace): """ + _settings = properties.Optional(properties.Object(DesignSpaceSettings), "metadata.settings") + subspaces = properties.List(properties.Object(DesignSpace), 'data.instance.subspaces', default=[]) dimensions = properties.Optional( @@ -50,6 +53,10 @@ def __init__(self, def _post_dump(self, data: dict) -> dict: data = super()._post_dump(data) + + if self._settings: + data["settings"] = self._settings.dump() + for i, subspace in enumerate(data['instance']['subspaces']): if isinstance(subspace, dict): # embedded design spaces are not modules, so only serialize their config diff --git a/src/citrine/resources/design_space.py b/src/citrine/resources/design_space.py index 17eb35a21..add28612e 100644 --- a/src/citrine/resources/design_space.py +++ b/src/citrine/resources/design_space.py @@ -3,28 +3,16 @@ from typing import Iterable, Optional, TypeVar, Union from uuid import UUID -from gemd.enumeration.base_enumeration import BaseEnumeration from citrine._utils.functions import format_escaped_url -from citrine.informatics.design_spaces import DesignSpace, EnumeratedDesignSpace, \ - HierarchicalDesignSpace +from citrine.informatics.design_spaces import DefaultDesignSpaceMode, DesignSpace, \ + DesignSpaceSettings, EnumeratedDesignSpace, HierarchicalDesignSpace from citrine._rest.collection import Collection from citrine._session import Session CreationType = TypeVar('CreationType', bound=DesignSpace) -class DefaultDesignSpaceMode(BaseEnumeration): - """The type of default design space to create. - - * ATTRIBUTE results in a product design space containing dimensions required by the predictor - * HIERARCHICAL results in a hierarchical design space resembling the shape of training data - """ - - ATTRIBUTE = 'ATTRIBUTE' - HIERARCHICAL = 'HIERARCHICAL' - - class DesignSpaceCollection(Collection[DesignSpace]): """Represents the collection of design spaces as well as the resources belonging to it. @@ -154,7 +142,7 @@ def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]: def create_default(self, *, - predictor_id: UUID, + predictor_id: Union[UUID, str], predictor_version: Optional[Union[int, str]] = None, mode: DefaultDesignSpaceMode = DefaultDesignSpaceMode.ATTRIBUTE, include_ingredient_fraction_constraints: bool = False, @@ -209,19 +197,20 @@ def create_default(self, """ path = f'projects/{self.project_id}/design-spaces/default' - payload = { - "predictor_id": str(predictor_id), - "mode": mode.value, - "include_ingredient_fraction_constraints": include_ingredient_fraction_constraints, - "include_label_fraction_constraints": include_label_fraction_constraints, - "include_label_count_constraints": include_label_count_constraints, - "include_parameter_constraints": include_parameter_constraints - } - if predictor_version: - payload["predictor_version"] = predictor_version + settings = DesignSpaceSettings( + predictor_id=predictor_id, + predictor_version=predictor_version, + mode=mode, + include_ingredient_fraction_constraints=include_ingredient_fraction_constraints, + include_label_fraction_constraints=include_label_fraction_constraints, + include_label_count_constraints=include_label_count_constraints, + include_parameter_constraints=include_parameter_constraints + ) - data = self.session.post_resource(path, json=payload, version=self._api_version) - return self.build(DesignSpace.wrap_instance(data["instance"])) + data = self.session.post_resource(path, json=settings.dump(), version=self._api_version) + ds = self.build(DesignSpace.wrap_instance(data["instance"])) + ds._settings = settings + return ds def convert_to_hierarchical( self, diff --git a/tests/resources/test_design_space.py b/tests/resources/test_design_space.py index e52b05fd0..972b8b207 100644 --- a/tests/resources/test_design_space.py +++ b/tests/resources/test_design_space.py @@ -7,8 +7,9 @@ from citrine.exceptions import ModuleRegistrationFailedException, NotFound from citrine.informatics.descriptors import RealDescriptor, FormulationKey -from citrine.informatics.design_spaces import EnumeratedDesignSpace, DesignSpace, ProductDesignSpace -from citrine.resources.design_space import DesignSpaceCollection, DefaultDesignSpaceMode +from citrine.informatics.design_spaces import DefaultDesignSpaceMode, DesignSpace, \ + DesignSpaceSettings, EnumeratedDesignSpace, HierarchicalDesignSpace, ProductDesignSpace +from citrine.resources.design_space import DesignSpaceCollection from citrine.resources.status_detail import StatusDetail, StatusLevelEnum from tests.utils.session import FakeCall, FakeSession @@ -192,16 +193,15 @@ def test_create_default(predictor_version, valid_product_design_space): session=session ) - expected_payload = { - "predictor_id": str(predictor_id), - "include_ingredient_fraction_constraints": False, - "include_label_fraction_constraints": False, - "include_label_count_constraints": False, - "include_parameter_constraints": False, - "mode": DefaultDesignSpaceMode.ATTRIBUTE.value, - } - if predictor_version is not None: - expected_payload["predictor_version"] = predictor_version + expected_payload = DesignSpaceSettings( + predictor_id=predictor_id, + predictor_version=predictor_version, + include_ingredient_fraction_constraints=False, + include_label_fraction_constraints=False, + include_label_count_constraints=False, + include_parameter_constraints=False, + mode=DefaultDesignSpaceMode.ATTRIBUTE + ).dump() expected_call = FakeCall( method='POST', @@ -215,7 +215,51 @@ def test_create_default(predictor_version, valid_product_design_space): assert session.num_calls == 1 assert session.last_call == expected_call - assert default_design_space.dump() == valid_product_design_space.dump() + expected_response = {**valid_product_design_space.dump(), "settings": expected_payload} + assert default_design_space.dump() == expected_response + + +@pytest.mark.parametrize("predictor_version", (2, "1", "latest", None)) +def test_create_default_hierarchical(predictor_version, valid_hierarchical_design_space_data): + valid_hierarchical_design_space = HierarchicalDesignSpace.build(valid_hierarchical_design_space_data) + + session = FakeSession() + session.set_response(valid_hierarchical_design_space.dump()) + + predictor_id = uuid.uuid4() + collection = DesignSpaceCollection( + project_id=uuid.uuid4(), + session=session + ) + + expected_payload = DesignSpaceSettings( + predictor_id=predictor_id, + predictor_version=predictor_version, + include_ingredient_fraction_constraints=False, + include_label_fraction_constraints=False, + include_label_count_constraints=False, + include_parameter_constraints=False, + mode=DefaultDesignSpaceMode.HIERARCHICAL + ).dump() + + expected_call = FakeCall( + method='POST', + path=f"projects/{collection.project_id}/design-spaces/default", + json=expected_payload, + version="v3" + ) + + default_design_space = collection.create_default( + predictor_id=predictor_id, + predictor_version=predictor_version, + mode=DefaultDesignSpaceMode.HIERARCHICAL + ) + + assert session.num_calls == 1 + assert session.last_call == expected_call + + expected_response = {**valid_hierarchical_design_space.dump(), "settings": expected_payload} + assert default_design_space.dump() == expected_response @pytest.mark.parametrize("ingredient_fractions", (True, False)) @@ -233,19 +277,21 @@ def test_create_default_with_config(valid_product_design_space, ingredient_fract project_id=uuid.uuid4(), session=session ) + + expected_payload = DesignSpaceSettings( + predictor_id=predictor_id, + predictor_version=predictor_version, + include_ingredient_fraction_constraints=ingredient_fractions, + include_label_fraction_constraints=label_fractions, + include_label_count_constraints=label_count, + include_parameter_constraints=parameters, + mode=DefaultDesignSpaceMode.ATTRIBUTE + ).dump() expected_call = FakeCall( method='POST', path=f"projects/{collection.project_id}/design-spaces/default", - json={ - "mode": DefaultDesignSpaceMode.ATTRIBUTE.value, - "predictor_id": str(predictor_id), - "predictor_version": predictor_version, - "include_ingredient_fraction_constraints": ingredient_fractions, - "include_label_fraction_constraints": label_fractions, - "include_label_count_constraints": label_count, - "include_parameter_constraints": parameters - }, + json=expected_payload, version="v3" ) @@ -261,7 +307,8 @@ def test_create_default_with_config(valid_product_design_space, ingredient_fract assert session.num_calls == 1 assert session.last_call == expected_call - assert default_design_space.dump() == valid_product_design_space.dump() + expected_response = {**valid_product_design_space.dump(), "settings": expected_payload} + assert default_design_space.dump() == expected_response def test_list_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data): @@ -408,3 +455,82 @@ def test_delete_not_supported(): dsc = DesignSpaceCollection(uuid.uuid4(), FakeSession()) with pytest.raises(NotImplementedError): dsc.delete(uuid.uuid4()) + + +def test_carrying_settings_from_create_default(valid_product_design_space): + predictor_id = uuid.uuid4() + predictor_version = 4 + + session = FakeSession() + + ds_resp = _ds_to_response(valid_product_design_space) + session.set_responses(ds_resp["data"], deepcopy(ds_resp), deepcopy(ds_resp)) + + collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session) + + default_design_space = collection.create_default( + predictor_id=predictor_id, + predictor_version=predictor_version, + include_label_count_constraints=True + ) + registered = collection.register(default_design_space) + + expected_settings = DesignSpaceSettings( + predictor_id=predictor_id, + predictor_version=predictor_version, + include_ingredient_fraction_constraints=False, + include_label_fraction_constraints=False, + include_label_count_constraints=True, + include_parameter_constraints=False, + mode=DefaultDesignSpaceMode.ATTRIBUTE + ) + expected_payload = {**valid_product_design_space.dump(), "settings": expected_settings.dump()} + + expected_call = FakeCall( + method='POST', + path=f"projects/{collection.project_id}/design-spaces", + json=expected_payload, + version="v3" + ) + + assert session.num_calls == 3 + assert session.calls[1] == expected_call + + +def test_carrying_settings_from_get(valid_product_design_space): + predictor_id = uuid.uuid4() + predictor_version = 4 + + session = FakeSession() + + expected_settings = DesignSpaceSettings( + predictor_id=predictor_id, + predictor_version=predictor_version, + exclude_intermediates=True, + include_ingredient_fraction_constraints=False, + include_label_fraction_constraints=False, + include_label_count_constraints=False, + include_parameter_constraints=True, + mode=DefaultDesignSpaceMode.ATTRIBUTE + ) + + ds_resp = _ds_to_response(valid_product_design_space) + ds_resp["metadata"]["settings"] = expected_settings.dump() + session.set_responses(deepcopy(ds_resp), deepcopy(ds_resp), deepcopy(ds_resp)) + + collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session) + + retrieved = collection.get(uuid.uuid4()) + registered = collection.register(retrieved) + + expected_payload = {**valid_product_design_space.dump(), "settings": expected_settings.dump()} + + expected_call = FakeCall( + method='POST', + path=f"projects/{collection.project_id}/design-spaces", + json=expected_payload, + version="v3" + ) + + assert session.num_calls == 3 + assert session.calls[1] == expected_call