Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.4.8"
__version__ = "3.5.0"
22 changes: 19 additions & 3 deletions src/citrine/informatics/predictors/auto_ml_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional, Set

from deprecation import deprecated
from gemd.enumeration.base_enumeration import BaseEnumeration

from citrine._rest.resource import Resource
Expand Down Expand Up @@ -52,7 +53,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode):
estimators: Optional[Set[AutoMLEstimator]]
Set of estimators to consider during AutoML model selection.
If None is provided, defaults to AutoMLEstimator.RANDOM_FOREST.
training_data: Optional[List[DataSource]]
training_data: Optional[List[DataSource]] (deprecated)
Sources of training data. Each can be either a CSV or a GEM Table. Candidates from
multiple data sources will be combined into a flattened list and de-duplicated by uid and
identifiers. De-duplication is performed if a uid or identifier is shared between two or
Expand All @@ -69,7 +70,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode):
'estimators',
default={AutoMLEstimator.RANDOM_FOREST}
)
training_data = _properties.List(
_training_data = _properties.List(
_properties.Object(DataSource),
'training_data',
default=[]
Expand All @@ -90,7 +91,22 @@ def __init__(self,
self.inputs: List[Descriptor] = inputs
self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST}
self.outputs = outputs
self.training_data: List[DataSource] = training_data or []
# self.training_data: List[DataSource] = training_data or []
if training_data:
self.training_data: List[DataSource] = training_data

@property
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
            details="Training data must be accessed through the top-level GraphPredictor.")
def training_data(self):
    """[DEPRECATED] Retrieve training data associated with this node.

    Reading this attribute goes through the ``deprecated`` decorator and
    returns whatever is currently stored in ``_training_data``.
    """
    return self._training_data

@training_data.setter
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
            details="Training data should only be added to the top-level GraphPredictor.")
def training_data(self, value):
    """[DEPRECATED] Set the training data associated with this node.

    The assigned value is stored unchanged in ``_training_data``.
    """
    self._training_data = value

def __str__(self):
    """Return a short display string that embeds the repr of ``self.name``."""
    template = '<AutoMLPredictor {!r}>'
    return template.format(self.name)
25 changes: 21 additions & 4 deletions src/citrine/informatics/predictors/mean_property_predictor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional, Mapping, Union
from typing import List, Mapping, Optional, Union

from deprecation import deprecated

from citrine._rest.resource import Resource
from citrine._serialization import properties as _properties
from citrine.informatics.data_sources import DataSource
from citrine.informatics.descriptors import (
FormulationDescriptor, RealDescriptor, CategoricalDescriptor
CategoricalDescriptor, FormulationDescriptor, RealDescriptor
)
from citrine.informatics.predictors import PredictorNode

Expand Down Expand Up @@ -79,7 +81,7 @@ class MeanPropertyPredictor(Resource["MeanPropertyPredictor"], PredictorNode):
),
'default_properties'
)
training_data = _properties.List(
_training_data = _properties.List(
_properties.Object(DataSource), 'training_data', default=[]
)

Expand All @@ -104,7 +106,22 @@ def __init__(self,
self.impute_properties: bool = impute_properties
self.label: Optional[str] = label
self.default_properties: Optional[Mapping[str, Union[str, float]]] = default_properties
self.training_data: List[DataSource] = training_data or []
# self.training_data: List[DataSource] = training_data or []
if training_data:
self.training_data: List[DataSource] = training_data

def __str__(self):
    """Return a short display string that embeds the repr of ``self.name``."""
    template = '<MeanPropertyPredictor {!r}>'
    return template.format(self.name)

@property
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
            details="Training data must be accessed through the top-level GraphPredictor.")
def training_data(self):
    """[DEPRECATED] Retrieve training data associated with this node.

    Reading this attribute goes through the ``deprecated`` decorator and
    returns whatever is currently stored in ``_training_data``.
    """
    return self._training_data

@training_data.setter
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
            details="Training data should only be added to the top-level GraphPredictor.")
def training_data(self, value):
    """[DEPRECATED] Set the training data associated with this node.

    The assigned value is stored unchanged in ``_training_data``.
    """
    self._training_data = value
20 changes: 18 additions & 2 deletions src/citrine/informatics/predictors/simple_mixture_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Optional

from deprecation import deprecated

from citrine._rest.resource import Resource
from citrine._serialization import properties
from citrine.informatics.data_sources import DataSource
Expand Down Expand Up @@ -28,7 +30,7 @@ class SimpleMixturePredictor(Resource["SimpleMixturePredictor"], PredictorNode):

"""

training_data = properties.List(properties.Object(DataSource), 'training_data', default=[])
_training_data = properties.List(properties.Object(DataSource), 'training_data', default=[])

typ = properties.String('type', default='SimpleMixture', deserializable=False)

Expand All @@ -39,7 +41,8 @@ def __init__(self,
training_data: Optional[List[DataSource]] = None):
self.name: str = name
self.description: str = description
self.training_data: List[DataSource] = training_data or []
if training_data:
self.training_data: List[DataSource] = training_data

def __str__(self):
    """Return a short display string that embeds the repr of ``self.name``."""
    template = '<SimpleMixturePredictor {!r}>'
    return template.format(self.name)
Expand All @@ -53,3 +56,16 @@ def input_descriptor(self) -> FormulationDescriptor:
def output_descriptor(self) -> FormulationDescriptor:
"""The output formulation descriptor with key 'Flat Formulation'."""
return FormulationDescriptor.flat()

@property
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
            details="Training data must be accessed through the top-level GraphPredictor.")
def training_data(self):
    """[DEPRECATED] Retrieve training data associated with this node.

    Reading this attribute goes through the ``deprecated`` decorator and
    returns whatever is currently stored in ``_training_data``.
    """
    return self._training_data

@training_data.setter
@deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
            details="Training data should only be added to the top-level GraphPredictor.")
def training_data(self, value):
    """[DEPRECATED] Set the training data associated with this node.

    The assigned value is stored unchanged in ``_training_data``.
    """
    self._training_data = value
56 changes: 56 additions & 0 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,24 @@ def test_auto_ml_multiple_outputs(auto_ml_multiple_outputs):
assert built.dump()['outputs'] == [z.dump(), y.dump()]


def test_auto_ml_deprecated_training_data(auto_ml):
    """Passing, assigning, and reading ``training_data`` each trigger a deprecation warning."""
    # Constructing with training data must warn.
    with pytest.deprecated_call():
        initial_sources = [GemTableDataSource(table_id=uuid.uuid4(), table_version=1)]
        predictor = AutoMLPredictor(
            name='AutoML Predictor',
            description='Predicts z from inputs w and x',
            inputs=auto_ml.inputs,
            outputs=auto_ml.outputs,
            training_data=initial_sources,
        )

    replacement_sources = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)]

    # Assigning through the attribute must warn.
    with pytest.deprecated_call():
        predictor.training_data = replacement_sources

    # Reading the attribute must warn and return the assigned value.
    with pytest.deprecated_call():
        observed = predictor.training_data
        assert observed == replacement_sources


def test_ing_to_formulation_initialization(ing_to_formulation_predictor):
"""Make sure the correct fields go to the correct places for an ingredients to formulation predictor."""
assert ing_to_formulation_predictor.name == 'Ingredients to formulation predictor'
Expand Down Expand Up @@ -361,6 +379,28 @@ def test_mean_property_round_robin(mean_property_predictor):
assert len(cat_props) == 1


def test_mean_property_training_data_deprecated(mean_property_predictor):
    """Passing, assigning, and reading ``training_data`` each trigger a deprecation warning."""
    template = mean_property_predictor

    # Constructing with training data must warn.
    with pytest.deprecated_call():
        initial_sources = [GemTableDataSource(table_id=uuid.uuid4(), table_version=1)]
        predictor = MeanPropertyPredictor(
            name='Mean property predictor',
            description='Computes mean ingredient properties',
            input_descriptor=template.input_descriptor,
            properties=template.properties,
            p=2.5,
            impute_properties=True,
            default_properties=template.default_properties,
            label=template.label,
            training_data=initial_sources,
        )

    replacement_sources = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)]

    # Assigning through the attribute must warn.
    with pytest.deprecated_call():
        predictor.training_data = replacement_sources

    # Reading the attribute must warn and return the assigned value.
    with pytest.deprecated_call():
        observed = predictor.training_data
        assert observed == replacement_sources


def test_label_fractions_property_initialization(label_fractions_predictor):
"""Make sure the correct fields go to the correct places for a label fraction predictor."""
assert label_fractions_predictor.name == 'Label fractions predictor'
Expand All @@ -379,6 +419,22 @@ def test_simple_mixture_predictor_initialization(simple_mixture_predictor):
assert str(simple_mixture_predictor) == expected_str


def test_simplex_mixture_training_data_deprecated():
    """Passing, assigning, and reading ``training_data`` each trigger a deprecation warning."""
    # Constructing with training data must warn.
    with pytest.deprecated_call():
        initial_sources = [GemTableDataSource(table_id=uuid.uuid4(), table_version=1)]
        predictor = SimpleMixturePredictor(
            name='Simple mixture predictor',
            description='Computes mean ingredient properties',
            training_data=initial_sources,
        )

    replacement_sources = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)]

    # Assigning through the attribute must warn.
    with pytest.deprecated_call():
        predictor.training_data = replacement_sources

    # Reading the attribute must warn and return the assigned value.
    with pytest.deprecated_call():
        observed = predictor.training_data
        assert observed == replacement_sources


def test_ingredient_fractions_property_initialization(ingredient_fractions_predictor):
"""Make sure the correct fields go to the correct places for an ingredient fractions predictor."""
assert ingredient_fractions_predictor.name == 'Ingredient fractions predictor'
Expand Down
6 changes: 4 additions & 2 deletions tests/serialization/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def test_auto_ml_deserialization(valid_auto_ml_predictor_data):
assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
assert len(predictor.outputs) == 1
assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
assert len(predictor.training_data) == 0
with pytest.deprecated_call():
assert len(predictor.training_data) == 0


def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data):
Expand All @@ -31,7 +32,8 @@ def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data):
assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
assert len(predictor.outputs) == 1
assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
assert len(predictor.training_data) == 0
with pytest.deprecated_call():
assert len(predictor.training_data) == 0


def test_legacy_serialization(valid_auto_ml_predictor_data):
Expand Down