diff --git a/docs/source/getting_started/ai_modules.rst b/docs/source/getting_started/ai_modules.rst
index a410592e2..db48a9e10 100644
--- a/docs/source/getting_started/ai_modules.rst
+++ b/docs/source/getting_started/ai_modules.rst
@@ -38,7 +38,7 @@ Design Workflow
 A Design Workflow combines a Design Space to define the materials of interest and a Predictor to predict material properties.
 They also include a :doc:`Score <../workflows/scores>` which codifies goals of the project.
 
-Predictor Evaluation Workflow
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Predictor Evaluation
+^^^^^^^^^^^^^^^^^^^^
 
-:doc:`Predictor Evaluation Workflows <../workflows/predictor_evaluation_workflows>` analyze the quality of a Predictor.
+:doc:`Predictor Evaluations <../workflows/predictor_evaluation_workflows>` analyze the quality of a Predictor.
diff --git a/docs/source/getting_started/basic_functionality.rst b/docs/source/getting_started/basic_functionality.rst
index 074b47ad2..d56a8d4a8 100644
--- a/docs/source/getting_started/basic_functionality.rst
+++ b/docs/source/getting_started/basic_functionality.rst
@@ -49,14 +49,12 @@ It is often useful to know when a resource has completed validating, especially
    sintering_model = sintering_project.predictors.register(sintering_model)
    wait_while_validating(collection=sintering_project.predictors, module=sintering_model)
 
-Similarly, the ``wait_while_executing`` function will wait for a design or performance evaluation workflow to complete executing.
+Similarly, the ``wait_while_executing`` function will wait for a design execution or predictor evaluation to finish executing.
 
 .. code-block:: python
 
-   pew_workflow = sintering_project.predictor_evaluation_workflows.register(pew_workflow)
-   pew_workflow = wait_while_validating(collection=sintering_project.predictor_evaluation_workflows, module=pew_workflow)
-   pew_ex = pew_workflow.trigger(sintering_model)
-   wait_while_executing(collection=sintering_project.predictor_evaluation_executions, execution=pew_ex, print_status_info=True)
+   predictor_evaluation = sintering_project.predictor_evaluations.trigger_default(predictor_id=sintering_model.uid)
+   wait_while_executing(collection=sintering_project.predictor_evaluations, execution=predictor_evaluation, print_status_info=True)
 
 Checking Status
 ---------------
diff --git a/docs/source/workflows/getting_started.rst b/docs/source/workflows/getting_started.rst
index 3294151e8..22a226778 100644
--- a/docs/source/workflows/getting_started.rst
+++ b/docs/source/workflows/getting_started.rst
@@ -11,9 +11,8 @@ These capabilities include generating candidates for Sequential Learning, identi
 Workflows Overview
 ------------------
 
-Currently, there are two workflows on the AI Engine: the :doc:`DesignWorkflow <design_workflows>` and the :doc:`PredictorEvaluationWorkflow <predictor_evaluation_workflows>`.
-Workflows employ reusable modules in order to execute.
-There are three different types of modules, and these are discussed in greater detail below.
+Currently, there are two workflows on the AI Engine: the :doc:`DesignWorkflow <design_workflows>` and the :doc:`PredictorEvaluation <predictor_evaluation_workflows>`.
+There are two different types of modules, and these are discussed in greater detail below.
 
 Design Workflow
 ***************
@@ -38,11 +37,11 @@ Branches
 A ``Branch`` is a named container which can contain any number of design workflows, and is purely a tool for organization.
 If you do not see branches in the Citrine Platform, you do not need to change how you work with design workflows.
 They will contain an additional field ``branch_id``, which you can ignore.
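+
+For example (a sketch; assumes ``workflow`` is an existing design workflow and uses the generic collection ``get`` accessor):
+
+.. code-block:: python
+
+   workflow = project.design_workflows.get(workflow.uid)
+   workflow.branch_id  # set by the platform; safe to ignore
+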
-Predictor Evaluation Workflow
-*****************************
+Predictor Evaluation
+********************
 
-The :doc:`PredictorEvaluationWorkflow <predictor_evaluation_workflows>` is used to analyze a :doc:`Predictor <predictors>`.
-This workflow helps users understand how well their predictor module works with their data: in essence, it describes the trustworthiness of their model.
+The :doc:`PredictorEvaluation <predictor_evaluation_workflows>` is used to analyze a :doc:`Predictor <predictors>`.
+It helps users understand how well their predictor module works with their data: in essence, it describes the trustworthiness of their model.
 These outcomes are captured in a series of response metrics.
 
 Modules Overview
@@ -80,17 +79,3 @@ Validation status can be one of the following states:
 - **Error:** Validation did not complete. An error was raised during the validation process that prevented an invalid or ready status to be determined.
 
 Validation of a workflow and all constituent modules must complete with ready status before the workflow can be executed.
-
-Experimental functionality
-**************************
-
-Both modules and workflows can be used to access experimental functionality on the platform.
-In some cases, the module or workflow type itself may be experimental.
-In other cases, whether a module or workflow represents experimental functionality may depend on the specific configuration of the module or workflow.
-For example, a module might have an experimental option that is turned off by default.
-Another example could be a workflow that contains an experimental module.
-Because the experimental status of a module or workflow may not be known at registration time, it is computed as part
-of the validation process and then returned via two fields:
-
-- `experimental` is a Boolean field that is true when the module or workflow is experimental
-- `experimental_reasons` is a list of strings that describe what about the module or workflow makes it experimental
diff --git a/docs/source/workflows/predictor_evaluation_workflows.rst b/docs/source/workflows/predictor_evaluation_workflows.rst
index a22116813..ca29fff45 100644
--- a/docs/source/workflows/predictor_evaluation_workflows.rst
+++ b/docs/source/workflows/predictor_evaluation_workflows.rst
@@ -1,15 +1,15 @@
-Predictor Evaluation Workflows
-==============================
+Predictor Evaluations
+=====================
 
-A :class:`~citrine.informatics.workflows.predictor_evaluation_workflow.PredictorEvaluationWorkflow` evaluates the performance of a :doc:`Predictor <predictors>`.
-Each workflow is composed of one or more :class:`PredictorEvaluators <citrine.informatics.predictor_evaluator.PredictorEvaluator>`.
+A :class:`~citrine.informatics.executions.predictor_evaluation.PredictorEvaluation` evaluates the performance of a :doc:`Predictor <predictors>`.
+Each evaluation utilizes one or more :class:`PredictorEvaluators <citrine.informatics.predictor_evaluator.PredictorEvaluator>`.
 
 Predictor evaluators
 --------------------
 
 A predictor evaluator defines a method to evaluate a predictor and any relevant configuration, e.g., k-fold cross-validation evaluation that specifies 3 folds.
 Minimally, each predictor evaluator specifies a name, a set of predictor responses to evaluate and a set of metrics to compute for each response.
-Evaluator names must be unique within a single workflow (more on that `below <#execution-and-results>`__).
+Evaluator names must be unique within a single evaluation (more on that `below <#execution-and-results>`__).
 Responses are specified as a set of strings, where each string corresponds to a descriptor key of a predictor output.
 Metrics are specified as a set of :class:`PredictorEvaluationMetrics <citrine.informatics.predictor_evaluation_metrics>`.
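+For example, a minimal evaluator might look like this (a sketch; the evaluator name and response key ``y`` are placeholders):
+
+.. code:: python
+
+   from citrine.informatics.predictor_evaluator import CrossValidationEvaluator
+   from citrine.informatics.predictor_evaluation_metrics import RMSE
+
+   evaluator = CrossValidationEvaluator(name='cv', responses={'y'}, metrics={RMSE()})
+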
 The evaluator will only compute the subset of metrics valid for each response, so the top-level metrics defined by an evaluator should contain the union of all metrics computed across all responses.
@@ -102,22 +102,21 @@ For categorical responses, performance metrics include the area under the receiv
 Execution and results
 ---------------------
 
-Triggering a Predictor Evaluation Workflow produces a :class:`~citrine.resources.predictor_evaluation_execution.PredictorEvaluationExecution`.
-This execution allows you to track the progress using its ``status`` and ``status_info`` properties.
-The ``status`` can be one of ``INPROGRESS``, ``READY``, or ``FAILED``.
-Information about the execution status, e.g., warnings or reasons for failure, can be accessed via ``status_info``.
+Once an evaluation is triggered, you can track its progress using its ``status`` and ``status_detail`` properties.
+The ``status`` can be one of ``INPROGRESS``, ``SUCCEEDED``, or ``FAILED``.
+Information about the evaluation's status, e.g., warnings or reasons for failure, can be accessed via ``status_detail``.
 
-When the ``status`` is ``READY``, results for each evaluation defined as part of the workflow can be accessed using the ``results`` method:
+When the ``status`` is ``SUCCEEDED``, results for each evaluator defined as part of the evaluation can be accessed using the ``results`` method:
 
 .. code:: python
 
-   results = execution.results('evaluator_name')
+   results = evaluation.results('evaluator_name')
 
-or by indexing into the execution object directly:
+or by indexing into the evaluation object directly:
 
 .. code:: python
 
-   results = execution['evaluator_name']
+   results = evaluation['evaluator_name']
 
 Both methods return a :class:`~citrine.informatics.predictor_evaluation_result.PredictorEvaluationResult`.
@@ -153,7 +152,7 @@ Each data point defines properties ``uuid``, ``identifiers``, ``trial``, ``fold`
 Example
 -------
 
-The following demonstrates how to create a :class:`~citrine.informatics.predictor_evaluator.CrossValidationEvaluator`, add it to a :class:`~citrine.informatics.workflows.predictor_evaluation_workflow.PredictorEvaluationWorkflow`, and use it to evaluate a :class:`~citrine.informatics.predictors.predictor.Predictor`.
+The following demonstrates how to create a :class:`~citrine.informatics.predictor_evaluator.CrossValidationEvaluator` and use it to evaluate a :class:`~citrine.informatics.predictors.predictor.Predictor`.
 
 The predictor we'll evaluate is defined below:
@@ -215,36 +214,19 @@ In this example we'll create a cross-validation evaluator for the response ``y``
        metrics={RMSE(), PVA()}
    )
 
-Then add the evaluator to a :class:`~citrine.informatics.workflows.predictor_evaluation_workflow.PredictorEvaluationWorkflow`, register it with your project, and wait for validation to finish:
+Then, trigger an evaluation and wait for the results to be ready:
 
 .. code:: python
 
-   from citrine.informatics.workflows import PredictorEvaluationWorkflow
-
-   workflow = PredictorEvaluationWorkflow(
-       name='workflow that evaluates y',
-       evaluators=[evaluator]
-   )
-
-   workflow = project.predictor_evaluation_workflows.register(workflow)
-   wait_while_validating(collection=project.predictor_evaluation_workflows, module=workflow)
-
-Trigger the workflow against a predictor to start an execution.
-Then wait for the results to be ready:
-
-.. 
code:: python
-
-   from citrine.jobs.waiting import wait_while_executing
-
-   execution = workflow.executions.trigger(predictor.uid, predictor_version=predictor.version)
-   wait_while_executing(collection=project.predictor_evaluation_executions, execution=execution, print_status_info=True)
+   from citrine.jobs.waiting import wait_while_executing
+
+   evaluation = project.predictor_evaluations.trigger(evaluators=[evaluator], predictor_id=predictor.uid)
+   wait_while_executing(collection=project.predictor_evaluations, execution=evaluation, print_status_info=True)
 
 Finally, load the results and inspect the metrics and their computed values:
 
 .. code:: python
 
    # load the results computed by the CV evaluator defined above
-   cv_results = execution[evaluator.name]
+   cv_results = evaluation[evaluator.name]
 
    # load results for y
    y_results = cv_results['y']
@@ -280,18 +262,17 @@ Finally, load the results and inspect the metrics and their computed values:
 Archive and restore
 -------------------
 
-Both :class:`PredictorEvaluationWorkflows <citrine.informatics.workflows.predictor_evaluation_workflow.PredictorEvaluationWorkflow>` and :class:`PredictorEvaluationExecutions <citrine.informatics.executions.predictor_evaluation_execution.PredictorEvaluationExecution>` can be archived and restored.
-To archive a workflow:
+:class:`PredictorEvaluations <citrine.informatics.executions.predictor_evaluation.PredictorEvaluation>` can be archived and restored.
+To archive an evaluation:
 
 .. code:: python
 
-   project.predictor_evaluation_workflows.archive(workflow.uid)
+   project.predictor_evaluations.archive(evaluation.uid)
 
-and to archive all executions associated with a predictor evaluation workflow:
+and to archive all evaluations associated with a predictor:
 
 .. code:: python
 
-   for execution in workflow.executions.list():
-       project.predictor_evaluation_executions.archive(execution.uid)
+   for evaluation in project.predictor_evaluations.list(predictor_id=predictor.uid):
+       project.predictor_evaluations.archive(evaluation.uid)
 
-To restore a workflow or execution, simply replace ``archive`` with ``restore`` in the code above.
+To restore an evaluation, simply replace ``archive`` with ``restore`` in the code above.
diff --git a/docs/source/workflows/predictors.rst b/docs/source/workflows/predictors.rst
index 258d833f5..8b03129e0 100644
--- a/docs/source/workflows/predictors.rst
+++ b/docs/source/workflows/predictors.rst
@@ -694,7 +694,7 @@ Predictor reports
 A :doc:`predictor report ` describes a machine-learned model, for example its settings and what features are important to the model.
 It does not include predictor evaluation metrics.
-To learn more about predictor evaluation metrics, please see :doc:`PredictorEvaluationWorkflow <predictor_evaluation_workflows>`.
+To learn more about predictor evaluation metrics, please see :doc:`PredictorEvaluation <predictor_evaluation_workflows>`.
diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py
index ce13706c4..3e6d8c7b7 100644
--- a/src/citrine/__version__.py
+++ b/src/citrine/__version__.py
@@ -1 +1 @@
-__version__ = "3.23.1"
+__version__ = "3.24.0"
diff --git a/src/citrine/_rest/engine_resource.py b/src/citrine/_rest/engine_resource.py
index e7f0e9520..9abd9f568 100644
--- a/src/citrine/_rest/engine_resource.py
+++ b/src/citrine/_rest/engine_resource.py
@@ -42,13 +42,14 @@ def is_archived(self):
         return self.archived_by is not None
 
     def _post_dump(self, data: dict) -> dict:
-        # Only the data portion of an entity is sent to the server.
-        data = data["data"]
-
-        if "instance" in data:
-            # Currently, name and description exists on both the data envelope and the config.
-            data["instance"]["name"] = data["name"]
-            data["instance"]["description"] = data["description"]
+        if data:
+            # Only the data portion of an entity is sent to the server.
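+            # Resources whose fields are all server-generated (e.g. a
+            # PredictorEvaluation) dump to an empty dict, hence the guard above.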
+            data = data["data"]
+
+            if "instance" in data:
+                # Currently, name and description exist on both the data envelope and the config.
+                data["instance"]["name"] = data["name"]
+                data["instance"]["description"] = data["description"]
 
         return super()._post_dump(data)
 
diff --git a/src/citrine/informatics/executions/predictor_evaluation.py b/src/citrine/informatics/executions/predictor_evaluation.py
new file mode 100644
index 000000000..ae8270337
--- /dev/null
+++ b/src/citrine/informatics/executions/predictor_evaluation.py
@@ -0,0 +1,98 @@
+from functools import lru_cache
+from typing import List, Optional, Union
+from uuid import UUID
+
+from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult
+from citrine.informatics.predictor_evaluator import PredictorEvaluator
+from citrine.resources.status_detail import StatusDetail
+from citrine._rest.engine_resource import EngineResourceWithoutStatus
+from citrine._rest.resource import PredictorRef
+from citrine._serialization import properties
+from citrine._serialization.serializable import Serializable
+from citrine._utils.functions import format_escaped_url
+
+
+class PredictorEvaluatorsResponse(Serializable['PredictorEvaluatorsResponse']):
+    """Container object for a default predictor evaluator response."""
+
+    evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators")
+
+    def __init__(self, evaluators: List[PredictorEvaluator]):
+        self.evaluators = evaluators
+
+
+class PredictorEvaluationRequest(Serializable['PredictorEvaluationRequest']):
+    """Container object for a predictor evaluation request."""
+
+    predictor = properties.Object(PredictorRef, "predictor")
+    evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators")
+
+    def __init__(self,
+                 *,
+                 evaluators: List[PredictorEvaluator],
+                 predictor_id: Union[UUID, str],
+                 predictor_version: Optional[Union[int, str]] = None):
+        self.evaluators = evaluators
+        self.predictor = PredictorRef(predictor_id, predictor_version)
+
+
+class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation']):
+    """The evaluation of a predictor's performance."""
+
+    uid: UUID = properties.UUID('id', serializable=False)
+    """:UUID: Unique identifier of the evaluation"""
+    evaluators = properties.List(properties.Object(PredictorEvaluator), "data.evaluators",
+                                 serializable=False)
+    """:List[PredictorEvaluator]: the predictor evaluators that were executed. These are used
+    when calling the ``results()`` method."""
+    predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
+    """:UUID: unique identifier of the predictor that was evaluated"""
+    predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
+    """:int: version of the predictor that was evaluated"""
+    status = properties.String('metadata.status.major', serializable=False)
+    """:str: short description of the evaluation's status"""
+    status_description = properties.String('metadata.status.minor', serializable=False)
+    """:str: more detailed description of the evaluation's status"""
+    status_detail = properties.List(properties.Object(StatusDetail), 'metadata.status.detail',
+                                    default=[], serializable=False)
+    """:List[StatusDetail]: a list of structured status info, containing the message and level"""
+
+    def _path(self):
+        return format_escaped_url(
+            '/projects/{project_id}/predictor-evaluations/{evaluation_id}',
+            project_id=str(self.project_id),
+            evaluation_id=str(self.uid)
+        )
+
+    @lru_cache()
+    def results(self, evaluator_name: str) -> PredictorEvaluationResult:
+        """
+        Get a specific evaluation result by the name of the evaluator that produced it.
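+
+        Results are fetched from the platform once per evaluator name and then
+        cached locally (note the ``lru_cache`` decorator).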
+ + Parameters + ---------- + evaluator_name: str + Name of the evaluator for which to get the results + + Returns + ------- + PredictorEvaluationResult + The evaluation result from the evaluator with the given name + + """ + params = {"evaluator_name": evaluator_name} + resource = self._session.get_resource(self._path() + "/results", params=params) + return PredictorEvaluationResult.build(resource) + + @property + def evaluator_names(self): + """Names of the predictor evaluators. Used when calling the ``results()`` method.""" + return list(iter(self)) + + def __getitem__(self, item): + if isinstance(item, str): + return self.results(item) + else: + raise TypeError("Results are accessed by string names") + + def __iter__(self): + return iter(e.name for e in self.evaluators) diff --git a/src/citrine/informatics/executions/predictor_evaluation_execution.py b/src/citrine/informatics/executions/predictor_evaluation_execution.py index 4e5a519f5..42a2cdb4e 100644 --- a/src/citrine/informatics/executions/predictor_evaluation_execution.py +++ b/src/citrine/informatics/executions/predictor_evaluation_execution.py @@ -8,7 +8,7 @@ class PredictorEvaluationExecution(Resource['PredictorEvaluationExecution'], Execution): - """The execution of a PredictorEvaluationWorkflow. + """[DEPRECATED] The execution of a PredictorEvaluationWorkflow. Possible statuses are INPROGRESS, SUCCEEDED, and FAILED. Predictor evaluation executions also have a ``status_description`` field with more information. diff --git a/src/citrine/informatics/predictor_evaluator.py b/src/citrine/informatics/predictor_evaluator.py index 523034aee..499f6407b 100644 --- a/src/citrine/informatics/predictor_evaluator.py +++ b/src/citrine/informatics/predictor_evaluator.py @@ -112,7 +112,8 @@ class CrossValidationEvaluator(Serializable["CrossValidationEvaluator"], Predict typ = properties.String("type", default="CrossValidationEvaluator", deserializable=False) def __init__(self, - name: str, *, + name: str, + *, description: str = "", responses: Set[str], n_folds: int = 5, diff --git a/src/citrine/informatics/workflows/predictor_evaluation_workflow.py b/src/citrine/informatics/workflows/predictor_evaluation_workflow.py index 54aaaeb06..c6af228e4 100644 --- a/src/citrine/informatics/workflows/predictor_evaluation_workflow.py +++ b/src/citrine/informatics/workflows/predictor_evaluation_workflow.py @@ -12,7 +12,7 @@ class PredictorEvaluationWorkflow(Resource['PredictorEvaluationWorkflow'], Workflow, AIResourceMetadata): - """A workflow that evaluations a predictor. + """[DEPRECATED] A workflow that evaluates a predictor. 
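+
+    Use ``project.predictor_evaluations`` to trigger evaluations directly instead.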
Parameters ---------- diff --git a/src/citrine/resources/design_space.py b/src/citrine/resources/design_space.py index add28612e..3a2112d08 100644 --- a/src/citrine/resources/design_space.py +++ b/src/citrine/resources/design_space.py @@ -129,15 +129,15 @@ def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None): per_page=per_page) def list_all(self, *, per_page: int = 20) -> Iterable[DesignSpace]: - """List the most recent version of all design spaces.""" + """List all design spaces.""" return self._list_base(per_page=per_page) def list(self, *, per_page: int = 20) -> Iterable[DesignSpace]: - """List the most recent version of all non-archived design spaces.""" + """List non-archived design spaces.""" return self._list_base(per_page=per_page, archived=False) def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]: - """List the most recent version of all archived predictors.""" + """List archived design spaces.""" return self._list_base(per_page=per_page, archived=True) def create_default(self, diff --git a/src/citrine/resources/predictor_evaluation.py b/src/citrine/resources/predictor_evaluation.py new file mode 100644 index 000000000..20156bd49 --- /dev/null +++ b/src/citrine/resources/predictor_evaluation.py @@ -0,0 +1,219 @@ +from functools import partial +from typing import Iterable, Iterator, List, Optional, Union +from uuid import UUID + +from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation, \ + PredictorEvaluationRequest, PredictorEvaluatorsResponse +from citrine.informatics.predictor_evaluator import PredictorEvaluator +from citrine.informatics.predictors import GraphPredictor +from citrine.resources.predictor import LATEST_VER as LATEST_PRED_VER +from citrine._rest.collection import Collection +from citrine._rest.resource import PredictorRef +from citrine._session import Session + + +class PredictorEvaluationCollection(Collection[PredictorEvaluation]): + """Represents the collection of predictor evaluations. 
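+
+    Accessed via ``project.predictor_evaluations``. Evaluations are created with
+    :meth:`trigger` or :meth:`trigger_default`; ``register`` is not supported.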
+ + Parameters + ---------- + project_id: UUID + the UUID of the project + + """ + + _api_version = 'v1' + _path_template = '/projects/{project_id}/predictor-evaluations' + _individual_key = None + _resource = PredictorEvaluation + _collection_key = 'response' + + def __init__(self, project_id: UUID, session: Session): + self.project_id = project_id + self.session: Session = session + + def build(self, data: dict) -> PredictorEvaluation: + """Build an individual predictor evaluation.""" + evaluation = PredictorEvaluation.build(data) + evaluation._session = self.session + evaluation.project_id = self.project_id + return evaluation + + def _list_base(self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + archived: Optional[bool] = None + ) -> Iterator[PredictorEvaluation]: + params = {"archived": archived} + if predictor_id is not None: + params["predictor_id"] = str(predictor_id) + if predictor_version is not None: + params["predictor_version"] = predictor_version + + fetcher = partial(self._fetch_page, additional_params=params) + return self._paginator.paginate(page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page) + + def list_all(self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None + ) -> Iterable[PredictorEvaluation]: + """List all predictor evaluations.""" + return self._list_base(per_page=per_page, + predictor_id=predictor_id, + predictor_version=predictor_version) + + def list(self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None + ) -> Iterable[PredictorEvaluation]: + """List non-archived predictor evaluations.""" + return self._list_base(per_page=per_page, + predictor_id=predictor_id, + predictor_version=predictor_version, + archived=False) + + def list_archived(self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None + ) -> Iterable[PredictorEvaluation]: + """List archived predictor evaluations.""" + return self._list_base(per_page=per_page, + predictor_id=predictor_id, + predictor_version=predictor_version, + archived=True) + + def archive(self, uid: Union[UUID, str]): + """Archive an evaluation.""" + url = self._get_path(uid, action="archive") + result = self.session.put_resource(url, {}, version=self._api_version) + return self.build(result) + + def restore(self, uid: Union[UUID, str]): + """Restore an archived evaluation.""" + url = self._get_path(uid, action="restore") + result = self.session.put_resource(url, {}, version=self._api_version) + return self.build(result) + + def default_from_config(self, config: GraphPredictor) -> List[PredictorEvaluator]: + """Retrieve the default evaluators for an arbitrary (but valid) predictor config. + + See :func:`~citrine.resources.PredictorEvaluationCollection.default` for details + on the resulting evaluators. + """ + path = self._get_path(action="default-from-config") + payload = config.dump()["instance"] + result = self.session.post_resource(path, json=payload, version=self._api_version) + return PredictorEvaluatorsResponse.build(result).evaluators + + def default(self, + *, + predictor_id: Union[UUID, str], + predictor_version: Union[int, str] = LATEST_PRED_VER + ) -> List[PredictorEvaluator]: + """Retrieve the default evaluators for a stored predictor. 
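+
+        These are the evaluators that :meth:`trigger_default` runs.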
+
+        The current default evaluators perform 5-fold, 3-trial cross-validation on all valid
+        predictor responses. Valid responses are those that are **not** produced by the
+        following predictors:
+
+        * :class:`~citrine.informatics.predictors.generalized_mean_property_predictor.GeneralizedMeanPropertyPredictor`
+        * :class:`~citrine.informatics.predictors.mean_property_predictor.MeanPropertyPredictor`
+        * :class:`~citrine.informatics.predictors.ingredient_fractions_predictor.IngredientFractionsPredictor`
+        * :class:`~citrine.informatics.predictors.ingredients_to_simple_mixture_predictor.IngredientsToSimpleMixturePredictor`
+        * :class:`~citrine.informatics.predictors.ingredients_to_formulation_predictor.IngredientsToFormulationPredictor`
+        * :class:`~citrine.informatics.predictors.label_fractions_predictor.LabelFractionsPredictor`
+        * :class:`~citrine.informatics.predictors.molecular_structure_featurizer.MolecularStructureFeaturizer`
+        * :class:`~citrine.informatics.predictors.simple_mixture_predictor.SimpleMixturePredictor`
+
+        Parameters
+        ----------
+        predictor_id: UUID
+            Unique identifier of the predictor to evaluate
+        predictor_version: Optional[Union[int, str]]
+            The version of the predictor to evaluate
+
+        Returns
+        -------
+        List[PredictorEvaluator]
+            The default evaluators for the predictor
+
+        """  # noqa: E501,W505
+        path = self._get_path(action="default")
+        payload = PredictorRef(uid=predictor_id, version=predictor_version).dump()
+        result = self.session.post_resource(path, json=payload, version=self._api_version)
+        return PredictorEvaluatorsResponse.build(result).evaluators
+
+    def trigger(self,
+                *,
+                predictor_id: Union[UUID, str],
+                predictor_version: Union[int, str] = LATEST_PRED_VER,
+                evaluators: List[PredictorEvaluator]) -> PredictorEvaluation:
+        """Evaluate a predictor using the provided evaluators.
+
+        Parameters
+        ----------
+        predictor_id: UUID
+            Unique identifier of the predictor to evaluate
+        predictor_version: Optional[Union[int, str]]
+            The version of the predictor to evaluate. Defaults to the latest trained version.
+        evaluators: List[PredictorEvaluator]
+            The evaluators to use to measure predictor performance.
+
+        Returns
+        -------
+        PredictorEvaluation
+
+        """
+        path = self._get_path("trigger")
+        payload = PredictorEvaluationRequest(evaluators=evaluators,
+                                             predictor_id=predictor_id,
+                                             predictor_version=predictor_version).dump()
+        result = self.session.post_resource(path, payload, version=self._api_version)
+        return self.build(result)
+
+    def trigger_default(self,
+                        *,
+                        predictor_id: Union[UUID, str],
+                        predictor_version: Union[int, str] = LATEST_PRED_VER
+                        ) -> PredictorEvaluation:
+        """Evaluate a predictor using the default evaluators.
+
+        See :func:`~citrine.resources.PredictorEvaluationCollection.default` for details on the evaluators.
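+
+        A typical call (a sketch; assumes ``predictor`` is a stored, trained predictor)::
+
+            evaluation = project.predictor_evaluations.trigger_default(predictor_id=predictor.uid)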
+
+        Parameters
+        ----------
+        predictor_id: UUID
+            Unique identifier of the predictor to evaluate
+        predictor_version: Optional[Union[int, str]]
+            The version of the predictor to evaluate
+
+        Returns
+        -------
+        PredictorEvaluation
+
+        """  # noqa: E501,W505
+        path = self._get_path("trigger-default")
+        payload = PredictorRef(uid=predictor_id, version=predictor_version).dump()
+        result = self.session.post_resource(path, json=payload, version=self._api_version)
+        return self.build(result)
+
+    def register(self, model: PredictorEvaluation) -> PredictorEvaluation:
+        """Cannot register an evaluation."""
+        raise NotImplementedError("Cannot register a PredictorEvaluation.")
+
+    def update(self, model: PredictorEvaluation) -> PredictorEvaluation:
+        """Cannot update an evaluation."""
+        raise NotImplementedError("Cannot update a PredictorEvaluation.")
+
+    def delete(self, uid: Union[UUID, str]):
+        """Cannot delete an evaluation."""
+        raise NotImplementedError("Cannot delete a PredictorEvaluation.")
diff --git a/src/citrine/resources/predictor_evaluation_execution.py b/src/citrine/resources/predictor_evaluation_execution.py
index 51bc91d5e..50b392708 100644
--- a/src/citrine/resources/predictor_evaluation_execution.py
+++ b/src/citrine/resources/predictor_evaluation_execution.py
@@ -1,4 +1,5 @@
 """Resources that represent both individual and collections of predictor evaluation executions."""
+from deprecation import deprecated
 from functools import partial
 from typing import Optional, Union, Iterator
 from uuid import UUID
@@ -19,6 +20,9 @@ class PredictorEvaluationExecutionCollection(Collection["PredictorEvaluationExec
     _collection_key = 'response'
     _resource = predictor_evaluation_execution.PredictorEvaluationExecution
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Predictor evaluation workflows are being eliminated in favor of directly "
+                        "evaluating predictors. Please use Project.predictor_evaluations instead.")
     def __init__(self,
                  project_id: UUID,
                  session: Session,
@@ -34,6 +38,8 @@ def build(self, data: dict) -> predictor_evaluation_execution.PredictorEvaluatio
         execution.project_id = self.project_id
         return execution
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluationCollection.trigger instead.")
     def trigger(self,
                 predictor_id: UUID,
                 *,
@@ -71,18 +77,22 @@ def trigger,
 
         return self.build(data)
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0")
     def register(self, model: predictor_evaluation_execution.PredictorEvaluationExecution
                  ) -> predictor_evaluation_execution.PredictorEvaluationExecution:
         """Cannot register an execution."""
         raise NotImplementedError("Cannot register a PredictorEvaluationExecution.")
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0")
     def update(self, model: predictor_evaluation_execution.PredictorEvaluationExecution
                ) -> predictor_evaluation_execution.PredictorEvaluationExecution:
         """Cannot update an execution."""
         raise NotImplementedError("Cannot update a PredictorEvaluationExecution.")
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluationCollection.archive")
     def archive(self, uid: Union[UUID, str]):
         """Archive a predictor evaluation execution.
 
@@ -94,6 +104,8 @@ def archive(self, uid: Union[UUID, str]):
         """
         self._put_resource_ref('archive', uid)
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluationCollection.restore")
     def restore(self, uid: Union[UUID, str]):
         """Restore an archived predictor evaluation execution. 
@@ -105,6 +117,8 @@
         """
         self._put_resource_ref('restore', uid)
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluationCollection.list")
     def list(self,
              *,
              per_page: int = 100,
@@ -144,7 +158,15 @@ def list(self,
                                        collection_builder=self._build_collection_elements,
                                        per_page=per_page)
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0")
     def delete(self, uid: Union[UUID, str]) -> Response:
         """Predictor Evaluation Executions cannot be deleted; they can be archived instead."""
         raise NotImplementedError(
             "Predictor Evaluation Executions cannot be deleted; they can be archived instead.")
+
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluationCollection.get")
+    def get(self,
+            uid: Union[UUID, str]) -> predictor_evaluation_execution.PredictorEvaluationExecution:
+        """Get a particular element of the collection."""
+        return super().get(uid)
diff --git a/src/citrine/resources/predictor_evaluation_workflow.py b/src/citrine/resources/predictor_evaluation_workflow.py
index 1fa257edf..e41e20ebc 100644
--- a/src/citrine/resources/predictor_evaluation_workflow.py
+++ b/src/citrine/resources/predictor_evaluation_workflow.py
@@ -1,5 +1,6 @@
 """Resources that represent both individual and collections of workflow executions."""
-from typing import Optional, Union
+from deprecation import deprecated
+from typing import Iterator, Optional, Union
 from uuid import UUID
 
 from citrine._rest.collection import Collection
@@ -16,6 +17,9 @@ class PredictorEvaluationWorkflowCollection(Collection[PredictorEvaluationWorkfl
     _collection_key = 'response'
     _resource = PredictorEvaluationWorkflow
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Predictor evaluation workflows are being eliminated in favor of directly "
+                        "evaluating predictors. Please use Project.predictor_evaluations instead.")
     def __init__(self, project_id: UUID, session: Session):
         self.project_id: UUID = project_id
         self.session: Session = session
@@ -27,6 +31,8 @@ def build(self, data: dict) -> PredictorEvaluationWorkflow:
         workflow.project_id = self.project_id
         return workflow
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluations instead, which doesn't store workflows.")
     def archive(self, uid: Union[UUID, str]):
         """Archive a predictor evaluation workflow.
 
@@ -38,6 +44,8 @@ def archive(self, uid: Union[UUID, str]):
         """
         return self._put_resource_ref('archive', uid)
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluations instead, which doesn't store workflows.")
     def restore(self, uid: Union[UUID, str] = None):
         """Restore an archived predictor evaluation workflow.
 
@@ -49,11 +57,15 @@ def restore(self, uid: Union[UUID, str] = None):
         """
         return self._put_resource_ref('restore', uid)
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0")
     def delete(self, uid: Union[UUID, str]) -> Response:
         """Predictor Evaluation Workflows cannot be deleted; they can be archived instead."""
         raise NotImplementedError(
             "Predictor Evaluation Workflows cannot be deleted; they can be archived instead.")
 
+    @deprecated(deprecated_in="3.23.0", removed_in="4.0.0",
+                details="Please use PredictorEvaluations.trigger_default instead. 
It doesn't store" + " a workflow, but it triggers an evaluation with the default evaluators.") def create_default(self, *, predictor_id: UUID, @@ -95,3 +107,27 @@ def create_default(self, payload['predictor_version'] = predictor_version data = self.session.post_resource(url, payload) return self.build(data) + + @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.") + def register(self, model: PredictorEvaluationWorkflow) -> PredictorEvaluationWorkflow: + """Create a new element of the collection by registering an existing resource.""" + return super().register(model) + + @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.") + def list(self, *, per_page: int = 100) -> Iterator[PredictorEvaluationWorkflow]: + """Paginate over the elements of the collection.""" + return super().list(per_page=per_page) + + @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.") + def update(self, model: PredictorEvaluationWorkflow) -> PredictorEvaluationWorkflow: + """Update a particular element of the collection.""" + return super().update(model) + + @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.") + def get(self, uid: Union[UUID, str]) -> PredictorEvaluationWorkflow: + """Get a particular element of the collection.""" + return super().get(uid) diff --git a/src/citrine/resources/project.py b/src/citrine/resources/project.py index 561374c6e..6dedfede6 100644 --- a/src/citrine/resources/project.py +++ b/src/citrine/resources/project.py @@ -40,6 +40,7 @@ PredictorEvaluationExecutionCollection from citrine.resources.predictor_evaluation_workflow import \ PredictorEvaluationWorkflowCollection +from citrine.resources.predictor_evaluation import PredictorEvaluationCollection from citrine.resources.generative_design_execution import \ GenerativeDesignExecutionCollection from citrine.resources.project_member import ProjectMember @@ -148,6 +149,11 @@ def predictor_evaluation_executions(self) -> PredictorEvaluationExecutionCollect """Return a collection representing all visible predictor evaluation executions.""" return PredictorEvaluationExecutionCollection(project_id=self.uid, session=self.session) + @property + def predictor_evaluations(self) -> PredictorEvaluationCollection: + """Return a collection representing all visible predictor evaluations.""" + return PredictorEvaluationCollection(project_id=self.uid, session=self.session) + @property def design_workflows(self) -> DesignWorkflowCollection: """Return a collection representing all visible design workflows.""" diff --git a/tests/informatics/test_predictor_evaluations.py b/tests/informatics/test_predictor_evaluations.py new file mode 100644 index 000000000..1ac41b9ee --- /dev/null +++ b/tests/informatics/test_predictor_evaluations.py @@ -0,0 +1,88 @@ +import uuid + +import pytest + +from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation, PredictorEvaluationRequest, PredictorEvaluatorsResponse +from citrine.informatics.predictor_evaluator import CrossValidationEvaluator +from citrine.informatics.predictor_evaluation_metrics import NDME +from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult +from citrine._rest.resource import PredictorRef +from 
tests.utils.session import FakeCall, FakeSession + + +@pytest.fixture +def cross_validation_evaluator(): + yield CrossValidationEvaluator("foo", description="desc", responses={"dk"}, n_folds=2, n_trials=5, metrics={NDME()}) + + +@pytest.fixture +def predictor_ref(): + yield PredictorRef(uuid.uuid4(), 4) + + +@pytest.fixture +def predictor_evaluators_response(cross_validation_evaluator): + yield PredictorEvaluatorsResponse([cross_validation_evaluator]) + + +@pytest.fixture +def predictor_evaluation_request(cross_validation_evaluator, predictor_ref): + yield PredictorEvaluationRequest(evaluators=[cross_validation_evaluator], predictor_id=predictor_ref.uid, predictor_version=predictor_ref.version) + + +@pytest.fixture +def predictor_evaluation(cross_validation_evaluator, predictor_ref): + evaluation = PredictorEvaluation() + evaluation.uid = uuid.uuid4() + evaluation.evaluators = [cross_validation_evaluator] + evaluation.predictor_id = predictor_ref.uid + evaluation.predictor_version = predictor_ref.version + evaluation.status = 'SUCCEEDED' + evaluation.status_description = 'COMPLETED' + yield evaluation + + +def test_predictor_evaluator_response(predictor_evaluators_response, cross_validation_evaluator): + assert predictor_evaluators_response.evaluators == [cross_validation_evaluator] + + +def test_predictor_evaluator_request(predictor_evaluation_request, cross_validation_evaluator, predictor_ref): + assert predictor_evaluation_request.evaluators == [cross_validation_evaluator] + assert predictor_evaluation_request.predictor.dump() == predictor_ref.dump() + + +def test_predictor_evaluation(predictor_evaluation, cross_validation_evaluator, predictor_ref): + assert predictor_evaluation.evaluators == [cross_validation_evaluator] + assert predictor_evaluation.evaluator_names == [cross_validation_evaluator.name] + assert predictor_evaluation.predictor_id == predictor_ref.uid + assert predictor_evaluation.predictor_version == predictor_ref.version + assert predictor_evaluation.status == 'SUCCEEDED' + assert predictor_evaluation.status_description == 'COMPLETED' + assert predictor_evaluation.status_detail == [] + + +def test_results(predictor_evaluation, example_cv_result_dict): + session = FakeSession() + predictor_evaluation._session = session + predictor_evaluation.project_id = uuid.uuid4() + + session.set_response(example_cv_result_dict) + + results = predictor_evaluation["Example Evaluator"] + + expected_call = FakeCall( + method='GET', + path=f'/projects/{predictor_evaluation.project_id}/predictor-evaluations/{predictor_evaluation.uid}/results', + params={"evaluator_name": "Example Evaluator"} + ) + assert session.last_call == expected_call + assert results.evaluator == PredictorEvaluationResult.build(example_cv_result_dict).evaluator + + +def test_results_invalid_type(predictor_evaluation): + session = FakeSession() + predictor_evaluation._session = session + predictor_evaluation.project_id = uuid.uuid4() + + with pytest.raises(TypeError): + predictor_evaluation[1] diff --git a/tests/informatics/workflows/test_predictor_evaluation_workflow.py b/tests/informatics/workflows/test_predictor_evaluation_workflow.py index 8c836bf11..eee00a1cc 100644 --- a/tests/informatics/workflows/test_predictor_evaluation_workflow.py +++ b/tests/informatics/workflows/test_predictor_evaluation_workflow.py @@ -36,4 +36,5 @@ def test_execution_error(pew): pew.executions pew.project_id = "foo" - assert pew.executions.project_id == "foo" + with pytest.deprecated_call(): + assert pew.executions.project_id == "foo" 
diff --git a/tests/resources/test_predictor_evaluation_executions.py b/tests/resources/test_predictor_evaluation_executions.py index 312701ad7..53c548282 100644 --- a/tests/resources/test_predictor_evaluation_executions.py +++ b/tests/resources/test_predictor_evaluation_executions.py @@ -17,11 +17,12 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> PredictorEvaluationExecutionCollection: - return PredictorEvaluationExecutionCollection( - project_id=uuid.uuid4(), - workflow_id=uuid.uuid4(), - session=session, - ) + with pytest.deprecated_call(): + return PredictorEvaluationExecutionCollection( + project_id=uuid.uuid4(), + workflow_id=uuid.uuid4(), + session=session, + ) @pytest.fixture @@ -37,11 +38,13 @@ def test_basic_methods(workflow_execution, collection): assert "Example evaluator" in list(iter(workflow_execution)) - with pytest.raises(NotImplementedError): - collection.register(workflow_execution) + with pytest.deprecated_call(): + with pytest.raises(NotImplementedError): + collection.register(workflow_execution) - with pytest.raises(NotImplementedError): - collection.update(workflow_execution) + with pytest.deprecated_call(): + with pytest.raises(NotImplementedError): + collection.update(workflow_execution) def test_build_new_execution(collection, predictor_evaluation_execution_dict): @@ -87,7 +90,8 @@ def test_trigger_workflow_execution(collection: PredictorEvaluationExecutionColl session.set_response(predictor_evaluation_execution_dict) # When - actual_execution = collection.trigger(predictor_id, random_state=random_state) + with pytest.deprecated_call(): + actual_execution = collection.trigger(predictor_id, random_state=random_state) # Then assert str(actual_execution.uid) == predictor_evaluation_execution_dict["id"] @@ -110,7 +114,8 @@ def test_trigger_workflow_execution_with_version(collection: PredictorEvaluation session.set_response(predictor_evaluation_execution_dict) # When - actual_execution = collection.trigger(predictor_id, predictor_version=predictor_version) + with pytest.deprecated_call(): + actual_execution = collection.trigger(predictor_id, predictor_version=predictor_version) # Then assert str(actual_execution.uid) == predictor_evaluation_execution_dict["id"] @@ -129,7 +134,8 @@ def test_trigger_workflow_execution_with_version(collection: PredictorEvaluation def test_list(collection: PredictorEvaluationExecutionCollection, session, predictor_version): session.set_response({"page": 1, "per_page": 4, "next": "", "response": []}) predictor_id = uuid.uuid4() - lst = list(collection.list(per_page=4, predictor_id=predictor_id, predictor_version=predictor_version)) + with pytest.deprecated_call(): + lst = list(collection.list(per_page=4, predictor_id=predictor_id, predictor_version=predictor_version)) assert not lst expected_path = '/projects/{}/predictor-evaluation-executions'.format(collection.project_id) @@ -140,17 +146,29 @@ def test_list(collection: PredictorEvaluationExecutionCollection, session, predi def test_archive(workflow_execution, collection): - collection.archive(workflow_execution.uid) + with pytest.deprecated_call(): + collection.archive(workflow_execution.uid) expected_path = '/projects/{}/predictor-evaluation-executions/archive'.format(collection.project_id) assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={"module_uid": str(workflow_execution.uid)}) def test_restore(workflow_execution, collection): - collection.restore(workflow_execution.uid) + with pytest.deprecated_call(): + 
collection.restore(workflow_execution.uid) expected_path = '/projects/{}/predictor-evaluation-executions/restore'.format(collection.project_id) assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={"module_uid": str(workflow_execution.uid)}) def test_delete(collection): - with pytest.raises(NotImplementedError): - collection.delete(uuid.uuid4()) + with pytest.deprecated_call(): + with pytest.raises(NotImplementedError): + collection.delete(uuid.uuid4()) + +def test_get(predictor_evaluation_execution_dict, workflow_execution, collection): + collection.session.set_response(predictor_evaluation_execution_dict) + + with pytest.deprecated_call(): + execution = collection.get(workflow_execution.uid) + + expected_path = f'/projects/{collection.project_id}/predictor-evaluation-executions/{workflow_execution.uid}' + assert collection.session.last_call == FakeCall(method='GET', path=expected_path) diff --git a/tests/resources/test_predictor_evaluation_workflows.py b/tests/resources/test_predictor_evaluation_workflows.py index d5aa8ab74..d231a583c 100644 --- a/tests/resources/test_predictor_evaluation_workflows.py +++ b/tests/resources/test_predictor_evaluation_workflows.py @@ -15,10 +15,11 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> PredictorEvaluationWorkflowCollection: - return PredictorEvaluationWorkflowCollection( - project_id=uuid.uuid4(), - session=session, - ) + with pytest.deprecated_call(): + return PredictorEvaluationWorkflowCollection( + project_id=uuid.uuid4(), + session=session, + ) @pytest.fixture @@ -33,22 +34,25 @@ def test_basic_methods(workflow, collection): def test_archive(workflow, collection): - collection.archive(workflow.uid) + with pytest.deprecated_call(): + collection.archive(workflow.uid) expected_path = '/projects/{}/predictor-evaluation-workflows/archive'.format(collection.project_id) assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={"module_uid": str(workflow.uid)}) def test_restore(workflow, collection): - collection.restore(workflow.uid) + with pytest.deprecated_call(): + collection.restore(workflow.uid) expected_path = '/projects/{}/predictor-evaluation-workflows/restore'.format(collection.project_id) assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={"module_uid": str(workflow.uid)}) def test_delete(collection): - with pytest.raises(NotImplementedError): - collection.delete(uuid.uuid4()) + with pytest.deprecated_call(): + with pytest.raises(NotImplementedError): + collection.delete(uuid.uuid4()) @pytest.mark.parametrize("predictor_version", (2, "1", "latest", None)) @@ -60,11 +64,13 @@ def test_create_default(predictor_evaluation_workflow_dict: dict, session = FakeSession() session.set_response(predictor_evaluation_workflow_dict) - collection = PredictorEvaluationWorkflowCollection( - project_id=project_id, - session=session - ) - default_workflow = collection.create_default(predictor_id=predictor_id, predictor_version=predictor_version) + with pytest.deprecated_call(): + collection = PredictorEvaluationWorkflowCollection( + project_id=project_id, + session=session + ) + with pytest.deprecated_call(): + default_workflow = collection.create_default(predictor_id=predictor_id, predictor_version=predictor_version) url = f'/projects/{collection.project_id}/predictor-evaluation-workflows/default' @@ -73,3 +79,39 @@ def test_create_default(predictor_evaluation_workflow_dict: dict, expected_payload["predictor_version"] = 
predictor_version assert session.calls == [FakeCall(method="POST", path=url, json=expected_payload)] assert default_workflow.dump() == workflow.dump() + +def test_register(predictor_evaluation_workflow_dict, workflow, collection): + collection.session.set_response(predictor_evaluation_workflow_dict) + + with pytest.deprecated_call(): + collection.register(workflow) + + expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows' + assert collection.session.last_call == FakeCall(method='POST', path=expected_path, json=workflow.dump()) + +def test_list(predictor_evaluation_workflow_dict, workflow, collection): + collection.session.set_response({"page": 1, "per_page": 4, "next": "", "response": [predictor_evaluation_workflow_dict]}) + + with pytest.deprecated_call(): + list(collection.list(per_page=20)) + + expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows' + assert collection.session.last_call == FakeCall(method='GET', path=expected_path, params={"per_page": 20, "page": 1}) + +def test_update(predictor_evaluation_workflow_dict, workflow, collection): + collection.session.set_response(predictor_evaluation_workflow_dict) + + with pytest.deprecated_call(): + collection.update(workflow) + + expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows/{workflow.uid}' + assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json=workflow.dump()) + +def test_get(predictor_evaluation_workflow_dict, workflow, collection): + collection.session.set_response(predictor_evaluation_workflow_dict) + + with pytest.deprecated_call(): + collection.get(workflow.uid) + + expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows/{workflow.uid}' + assert collection.session.last_call == FakeCall(method='GET', path=expected_path) diff --git a/tests/resources/test_predictor_evaluations.py b/tests/resources/test_predictor_evaluations.py new file mode 100644 index 000000000..73e6789d8 --- /dev/null +++ b/tests/resources/test_predictor_evaluations.py @@ -0,0 +1,249 @@ +import uuid + +import pytest + +from citrine.resources.predictor_evaluation import PredictorEvaluationCollection +from citrine.informatics.executions.predictor_evaluation import PredictorEvaluationRequest +from citrine.informatics.predictors import GraphPredictor + +from tests.utils.factories import CrossValidationEvaluatorFactory, PredictorEvaluationDataFactory,\ + PredictorEvaluationFactory, PredictorInstanceDataFactory, PredictorRefFactory +from tests.utils.session import FakeCall, FakeSession + + +def paging_response(*items): + return {"response": items} + + +def test_get(): + evaluation_response = PredictorEvaluationFactory() + id = uuid.uuid4() + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(evaluation_response) + + pec.get(id) + + expected_call = FakeCall( + method='GET', + path=f'/projects/{pec.project_id}/predictor-evaluations/{id}', + params={} + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + + +def test_archived(): + evaluation_response = PredictorEvaluationFactory(is_archived=True) + id = uuid.uuid4() + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(evaluation_response) + + pec.archive(id) + + expected_call = FakeCall( + method='PUT', + path=f'/projects/{pec.project_id}/predictor-evaluations/{id}/archive', + json={} + ) + assert session.num_calls == 1 + 
assert expected_call == session.last_call + + +def test_restore(): + evaluation_response = PredictorEvaluationFactory() + id = uuid.uuid4() + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(evaluation_response) + + pec.restore(id) + + expected_call = FakeCall( + method='PUT', + path=f'/projects/{pec.project_id}/predictor-evaluations/{id}/restore', + json={} + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + + +def test_list(): + evaluation_response = PredictorEvaluationFactory() + pred_id = uuid.uuid4() + pred_ver = 2 + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(paging_response(evaluation_response)) + + evaluations = list(pec.list(predictor_id=pred_id, predictor_version=pred_ver)) + + expected_call = FakeCall( + method='GET', + path=f'/projects/{pec.project_id}/predictor-evaluations', + params={"page": 1, "per_page": 100, "predictor_id": str(pred_id), "predictor_version": pred_ver, "archived": False} + ) + + assert session.num_calls == 1 + assert expected_call == session.last_call + assert len(evaluations) == 1 + + +def test_list_archived(): + evaluation_response = PredictorEvaluationFactory(is_archived=True) + pred_id = uuid.uuid4() + pred_ver = 2 + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(paging_response(evaluation_response)) + + evaluations = list(pec.list_archived(predictor_id=pred_id, predictor_version=pred_ver)) + + expected_call = FakeCall( + method='GET', + path=f'/projects/{pec.project_id}/predictor-evaluations', + params={"page": 1, "per_page": 100, "predictor_id": str(pred_id), "predictor_version": pred_ver, "archived": True} + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + assert len(evaluations) == 1 + + +def test_list_all(): + evaluations = [PredictorEvaluationFactory(), PredictorEvaluationFactory(is_archived=True)] + pred_id = uuid.uuid4() + pred_ver = 2 + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(paging_response(*evaluations)) + + evaluations = list(pec.list_all(predictor_id=pred_id, predictor_version=pred_ver)) + + expected_call = FakeCall( + method='GET', + path=f'/projects/{pec.project_id}/predictor-evaluations', + params={"page": 1, "per_page": 100, "predictor_id": str(pred_id), "predictor_version": pred_ver, "archived": None} + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + assert len(evaluations) == 2 + + +def test_trigger(): + evaluators = [CrossValidationEvaluatorFactory()] + pred_ref = PredictorRefFactory() + evaluation_response = PredictorEvaluationFactory() + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(evaluation_response) + + pec.trigger(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"], evaluators=evaluators) + + expected_payload = PredictorEvaluationRequest(evaluators=evaluators, + predictor_id=pred_ref["predictor_id"], + predictor_version=pred_ref["predictor_version"]) + expected_call = FakeCall( + method='POST', + path=f'/projects/{pec.project_id}/predictor-evaluations/trigger', + json=expected_payload.dump() + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + + +def test_trigger_default(): + evaluation_response = PredictorEvaluationFactory() + pred_ref = 
PredictorRefFactory() + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(evaluation_response) + + pec.trigger_default(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"]) + + expected_call = FakeCall( + method='POST', + path=f'/projects/{pec.project_id}/predictor-evaluations/trigger-default', + json=pred_ref + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + + +def test_default(): + response = PredictorEvaluationDataFactory() + pred_ref = PredictorRefFactory() + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(response) + + default_evaluators = pec.default(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"]) + + expected_call = FakeCall( + method='POST', + path=f'/projects/{pec.project_id}/predictor-evaluations/default', + json=pred_ref + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + assert len(default_evaluators) == len(response["evaluators"]) + +def test_default_from_config(valid_graph_predictor_data): + response = PredictorEvaluationDataFactory() + config = GraphPredictor.build(valid_graph_predictor_data) + payload = config.dump()['instance'] + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + session.set_response(response) + + default_evaluators = pec.default_from_config(config) + + expected_call = FakeCall( + method='POST', + path=f'/projects/{pec.project_id}/predictor-evaluations/default-from-config', + json=payload + ) + assert session.num_calls == 1 + assert expected_call == session.last_call + assert len(default_evaluators) == len(response["evaluators"]) + + +def test_register_not_implemented(): + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + with pytest.raises(NotImplementedError): + pec.register(PredictorEvaluationDataFactory()) + + +def test_update_not_implemented(): + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + with pytest.raises(NotImplementedError): + pec.update(PredictorEvaluationDataFactory()) + + +def test_delete_not_implemented(): + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + with pytest.raises(NotImplementedError): + pec.delete(uuid.uuid4()) diff --git a/tests/resources/test_project.py b/tests/resources/test_project.py index 11bb97fa2..cf1053ec0 100644 --- a/tests/resources/test_project.py +++ b/tests/resources/test_project.py @@ -304,14 +304,21 @@ def test_predictors_get_project_id(project): def test_pe_workflows_get_project_id(project): - assert project.uid == project.predictor_evaluation_workflows.project_id + with pytest.deprecated_call(): + assert project.uid == project.predictor_evaluation_workflows.project_id def test_pe_executions_get_project_id(project): - assert project.uid == project.predictor_evaluation_executions.project_id + with pytest.deprecated_call(): + assert project.uid == project.predictor_evaluation_executions.project_id # The resulting collection cannot be used to trigger executions. 
+def test_trigger():
+    evaluators = [CrossValidationEvaluatorFactory()]
+    pred_ref = PredictorRefFactory()
+    evaluation_response = PredictorEvaluationFactory()
+
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+
+    session.set_response(evaluation_response)
+
+    pec.trigger(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"], evaluators=evaluators)
+
+    expected_payload = PredictorEvaluationRequest(evaluators=evaluators,
+                                                  predictor_id=pred_ref["predictor_id"],
+                                                  predictor_version=pred_ref["predictor_version"])
+    expected_call = FakeCall(
+        method='POST',
+        path=f'/projects/{pec.project_id}/predictor-evaluations/trigger',
+        json=expected_payload.dump()
+    )
+    assert session.num_calls == 1
+    assert expected_call == session.last_call
+
+
+def test_trigger_default():
+    evaluation_response = PredictorEvaluationFactory()
+    pred_ref = PredictorRefFactory()
+
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+
+    session.set_response(evaluation_response)
+
+    pec.trigger_default(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"])
+
+    expected_call = FakeCall(
+        method='POST',
+        path=f'/projects/{pec.project_id}/predictor-evaluations/trigger-default',
+        json=pred_ref
+    )
+    assert session.num_calls == 1
+    assert expected_call == session.last_call
+
+
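+# default() and default_from_config() return the platform's recommended evaluators
+# for a predictor (or a raw predictor config) rather than triggering an evaluation.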
+def test_default():
+    response = PredictorEvaluationDataFactory()
+    pred_ref = PredictorRefFactory()
+
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+
+    session.set_response(response)
+
+    default_evaluators = pec.default(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"])
+
+    expected_call = FakeCall(
+        method='POST',
+        path=f'/projects/{pec.project_id}/predictor-evaluations/default',
+        json=pred_ref
+    )
+    assert session.num_calls == 1
+    assert expected_call == session.last_call
+    assert len(default_evaluators) == len(response["evaluators"])
+
+
+def test_default_from_config(valid_graph_predictor_data):
+    response = PredictorEvaluationDataFactory()
+    config = GraphPredictor.build(valid_graph_predictor_data)
+    payload = config.dump()['instance']
+
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+
+    session.set_response(response)
+
+    default_evaluators = pec.default_from_config(config)
+
+    expected_call = FakeCall(
+        method='POST',
+        path=f'/projects/{pec.project_id}/predictor-evaluations/default-from-config',
+        json=payload
+    )
+    assert session.num_calls == 1
+    assert expected_call == session.last_call
+    assert len(default_evaluators) == len(response["evaluators"])
+
+
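+# PredictorEvaluationCollection only supports creation through the trigger routes,
+# so the inherited register/update/delete methods raise NotImplementedError.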
+def test_register_not_implemented():
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+    with pytest.raises(NotImplementedError):
+        pec.register(PredictorEvaluationDataFactory())
+
+
+def test_update_not_implemented():
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+    with pytest.raises(NotImplementedError):
+        pec.update(PredictorEvaluationDataFactory())
+
+
+def test_delete_not_implemented():
+    session = FakeSession()
+    pec = PredictorEvaluationCollection(uuid.uuid4(), session)
+    with pytest.raises(NotImplementedError):
+        pec.delete(uuid.uuid4())
diff --git a/tests/resources/test_project.py b/tests/resources/test_project.py
index 11bb97fa2..cf1053ec0 100644
--- a/tests/resources/test_project.py
+++ b/tests/resources/test_project.py
@@ -304,14 +304,21 @@ def test_predictors_get_project_id(project):
 
 
 def test_pe_workflows_get_project_id(project):
-    assert project.uid == project.predictor_evaluation_workflows.project_id
+    with pytest.deprecated_call():
+        assert project.uid == project.predictor_evaluation_workflows.project_id
 
 
 def test_pe_executions_get_project_id(project):
-    assert project.uid == project.predictor_evaluation_executions.project_id
+    with pytest.deprecated_call():
+        assert project.uid == project.predictor_evaluation_executions.project_id
 
     # The resulting collection cannot be used to trigger executions.
-    with pytest.raises(RuntimeError):
-        project.predictor_evaluation_executions.trigger(uuid.uuid4())
+    with pytest.deprecated_call():
+        with pytest.raises(RuntimeError):
+            project.predictor_evaluation_executions.trigger(uuid.uuid4())
+
+
+def test_predictor_evaluations_get_project_id(project):
+    assert project.uid == project.predictor_evaluations.project_id
 
 
 def test_design_workflows_get_project_id(project):
diff --git a/tests/utils/factories.py b/tests/utils/factories.py
index eb083de83..dc7f96399 100644
--- a/tests/utils/factories.py
+++ b/tests/utils/factories.py
@@ -463,12 +463,66 @@ class AsyncDefaultPredictorResponseFactory(factory.DictFactory):
     data = factory.SubFactory(AsyncDefaultPredictorResponseDataFactory)
 
 
-class PredictorEvaluationWorkflowDataFactory(factory.DictFactory):
+class RMSEFactory(factory.DictFactory):
+    type = "RMSE"
+
+
+class NDMEFactory(factory.DictFactory):
+    type = "NDME"
+
+
+class RSquaredFactory(factory.DictFactory):
+    type = "RSquared"
+
+
+class StandardRMSEFactory(factory.DictFactory):
+    type = "StandardRMSE"
+
+
+class PVALFactory(factory.DictFactory):
+    type = "PVA"
+
+
+class F1Factory(factory.DictFactory):
+    type = "F1"
+
+
+class AreaUnderROCFactory(factory.DictFactory):
+    type = "AreaUnderROC"
+
+
+class CoverageProbabilityFactory(factory.DictFactory):
+    class Meta:
+        exclude = ("_level", )
+
+    _level = factory.Faker('pyfloat', max_value=1, min_value=0)
+    coverage_level = factory.LazyAttribute(lambda o: str(o._level))
+    type = "CoverageProbability"
+
+
+class CrossValidationEvaluatorFactory(factory.DictFactory):
+    name = factory.Faker("company")
+    description = factory.Faker("catch_phrase")
+    responses = factory.List(3 * [factory.Faker('company')])
+    n_folds = factory.Faker('random_digit_not_null')
+    n_trials = factory.Faker('random_digit_not_null')
+    metrics = factory.List([factory.SubFactory(RMSEFactory),
+                            factory.SubFactory(NDMEFactory),
+                            factory.SubFactory(RSquaredFactory),
+                            factory.SubFactory(StandardRMSEFactory),
+                            factory.SubFactory(PVALFactory),
+                            factory.SubFactory(F1Factory),
+                            factory.SubFactory(AreaUnderROCFactory),
+                            factory.SubFactory(CoverageProbabilityFactory)])
+    type = "CrossValidationEvaluator"
+
+
+class PredictorEvaluationWorkflowFactory(factory.DictFactory):
     id = factory.Faker('uuid4')
     name = factory.Faker("company")
     description = factory.Faker("catch_phrase")
     archived = False
-    evaluators = []  # TODO Create EvaluatorDataFactory
+    evaluators = factory.List([factory.SubFactory(CrossValidationEvaluatorFactory)])
     type = "PredictorEvaluationWorkflow"
     # TODO Create Trait and status_detail content
     status = "SUCCEEDED"
@@ -476,6 +530,28 @@ class PredictorEvaluationWorkflowDataFactory(factory.DictFactory):
     status_detail = []
 
 
+class PredictorEvaluationDataFactory(factory.DictFactory):
+    evaluators = factory.List([factory.SubFactory(CrossValidationEvaluatorFactory)])
+
+
+class PredictorEvaluationMetadataFactory(factory.DictFactory):
+    class Meta:
+        exclude = ('is_archived', )
+
+    created = factory.SubFactory(UserTimestampDataFactory)
+    updated = factory.SubFactory(UserTimestampDataFactory)
+    archived = factory.Maybe('is_archived', factory.SubFactory(UserTimestampDataFactory), None)
+    predictor_id = factory.Faker("uuid4")
+    predictor_version = factory.Faker("random_digit_not_null")
+    status = {"major": "SUCCEEDED", "minor": "READY", "detail": []}
+
+
+class PredictorEvaluationFactory(factory.DictFactory):
+    id = factory.Faker('uuid4')
+    data = factory.SubFactory(PredictorEvaluationDataFactory)
+    metadata = factory.SubFactory(PredictorEvaluationMetadataFactory)
+
+
 class DesignSpaceConfigDataFactory(factory.DictFactory):
     id = factory.Faker('uuid4')
     name = factory.Faker("company")
diff --git a/tests/utils/fakes/__init__.py b/tests/utils/fakes/__init__.py
index 0914962b1..6176ea9f5 100644
--- a/tests/utils/fakes/__init__.py
+++ b/tests/utils/fakes/__init__.py
@@ -6,5 +6,5 @@ from .fake_table_collection import *
 from .fake_workflows import *
 from .fake_module_collection import FakeDesignSpaceCollection, FakePredictorCollection
-from .fake_workflow_collection import FakeDesignWorkflowCollection, FakePredictorEvaluationWorkflowCollection
+from .fake_workflow_collection import FakeDesignWorkflowCollection
 from .fake_project_collection import *
diff --git a/tests/utils/fakes/fake_execution_collection.py b/tests/utils/fakes/fake_execution_collection.py
index 55005955a..1719e4ece 100644
--- a/tests/utils/fakes/fake_execution_collection.py
+++ b/tests/utils/fakes/fake_execution_collection.py
@@ -1,11 +1,10 @@
 from uuid import UUID
 from typing import Optional
 
-from citrine.informatics.executions import DesignExecution, PredictorEvaluationExecution
+from citrine.informatics.executions import DesignExecution
 from citrine.informatics.scores import Score
 from citrine.resources.design_execution import DesignExecutionCollection
-from citrine.resources.predictor_evaluation_execution import PredictorEvaluationExecutionCollection
 
 
 class FakeDesignExecutionCollection(DesignExecutionCollection):
@@ -15,9 +14,3 @@ def trigger(self, execution_input: Score, max_candidates: Optional[int] = None)
         execution.score = execution_input
         execution.descriptors = []
         return execution
-
-
-class FakePredictorEvaluationExecutionCollection(PredictorEvaluationExecutionCollection):
-
-    def trigger(self, predictor_id: UUID) -> PredictorEvaluationExecution:
-        return PredictorEvaluationExecution()
diff --git a/tests/utils/fakes/fake_project_collection.py b/tests/utils/fakes/fake_project_collection.py
index dde806191..a0f6a657f 100644
--- a/tests/utils/fakes/fake_project_collection.py
+++ b/tests/utils/fakes/fake_project_collection.py
@@ -6,8 +6,7 @@ from tests.utils.fakes import FakeDatasetCollection
 from tests.utils.fakes import FakeDesignSpaceCollection, FakeDesignWorkflowCollection
 from tests.utils.fakes import FakeGemTableCollection, FakeTableConfigCollection
-from tests.utils.fakes import FakePredictorCollection, FakePredictorEvaluationWorkflowCollection
-from tests.utils.fakes import FakePredictorEvaluationExecutionCollection
+from tests.utils.fakes import FakePredictorCollection
 from tests.utils.fakes import FakeDescriptorMethods
 from tests.utils.session import FakeSession
@@ -67,8 +66,6 @@ def __init__(self, name="foo", description="bar", num_properties=3, session=Fake
         self._descriptor_methods = FakeDescriptorMethods(num_properties)
         self._datasets = FakeDatasetCollection(team_id=self.team_id, session=self.session)
         self._predictors = FakePredictorCollection(self.uid, self.session)
-        self._pees = FakePredictorEvaluationExecutionCollection(self.uid, self.session)
-        self._pews = FakePredictorEvaluationWorkflowCollection(self.uid, self.session)
         self._tables = FakeGemTableCollection(team_id=self.team_id, project_id=self.uid, session=self.session)
         self._table_configs = FakeTableConfigCollection(team_id=self.team_id, project_id=self.uid, session=self.session)
@@ -92,14 +89,6 @@ def descriptors(self) -> FakeDescriptorMethods:
     def predictors(self) -> FakePredictorCollection:
         return self._predictors
 
-    @property
-    def predictor_evaluation_executions(self) -> FakePredictorEvaluationExecutionCollection:
-        return self._pees
-
-    @property
-    def predictor_evaluation_workflows(self) -> FakePredictorEvaluationWorkflowCollection:
-        return self._pews
-
     @property
     def tables(self) -> FakeGemTableCollection:
         return self._tables
diff --git a/tests/utils/fakes/fake_workflow_collection.py b/tests/utils/fakes/fake_workflow_collection.py
index 89651dfa1..6c1ba2587 100644
--- a/tests/utils/fakes/fake_workflow_collection.py
+++ b/tests/utils/fakes/fake_workflow_collection.py
@@ -3,11 +3,10 @@
 from citrine._session import Session
 from citrine._utils.functions import migrate_deprecated_argument
-from citrine.informatics.workflows import PredictorEvaluationWorkflow, DesignWorkflow
-from citrine.resources.predictor_evaluation_workflow import PredictorEvaluationWorkflowCollection
+from citrine.informatics.workflows import DesignWorkflow
 from citrine.resources.design_workflow import DesignWorkflowCollection
 
-from tests.utils.fakes import FakeCollection, FakePredictorEvaluationWorkflow
+from tests.utils.fakes import FakeCollection
 
 WorkflowType = TypeVar('WorkflowType', bound='Workflow')
@@ -34,17 +33,3 @@ def archive(self, uid: Union[UUID, str]):
 
 class FakeDesignWorkflowCollection(FakeWorkflowCollection[DesignWorkflow], DesignWorkflowCollection):
     pass
-
-
-class FakePredictorEvaluationWorkflowCollection(FakeWorkflowCollection[PredictorEvaluationWorkflow], PredictorEvaluationWorkflowCollection):
-
-    def create_default(self, *, predictor_id: UUID) -> PredictorEvaluationWorkflow:
-        pew = FakePredictorEvaluationWorkflow(
-            name=f"Default predictor evaluation workflow",
-            description="",
-            evaluators=[]
-        )
-        pew.project_id = self.project_id
-        pew.uid = uuid4()
-        pew._session = self.session
-        return pew
diff --git a/tests/utils/fakes/fake_workflows.py b/tests/utils/fakes/fake_workflows.py
index a1f3db1b3..e1afc1d44 100644
--- a/tests/utils/fakes/fake_workflows.py
+++ b/tests/utils/fakes/fake_workflows.py
@@ -1,6 +1,6 @@
-from citrine.informatics.workflows import DesignWorkflow, PredictorEvaluationWorkflow
+from citrine.informatics.workflows import DesignWorkflow
 
-from tests.utils.fakes import FakeDesignExecutionCollection, FakePredictorEvaluationExecutionCollection
+from tests.utils.fakes import FakeDesignExecutionCollection
 
 
 class FakeDesignWorkflow(DesignWorkflow):
@@ -12,14 +12,3 @@ def design_executions(self) -> FakeDesignExecutionCollection:
             raise AttributeError('Cannot initialize execution without project reference!')
         return FakeDesignExecutionCollection(
             project_id=self.project_id, session=self._session, workflow_id=self.uid)
-
-
-class FakePredictorEvaluationWorkflow(PredictorEvaluationWorkflow):
-
-    @property
-    def executions(self) -> FakePredictorEvaluationExecutionCollection:
-        """Return a resource representing all visible executions of this workflow."""
-        if getattr(self, 'project_id', None) is None:
-            raise AttributeError('Cannot initialize execution without project reference!')
-        return FakePredictorEvaluationExecutionCollection(
-            project_id=self.project_id, session=self._session, workflow_id=self.uid)
\ No newline at end of file