diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index 9ce9954cf..fcbd6954b 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.10.0" +__version__ = "3.11.0" diff --git a/src/citrine/jobs/job.py b/src/citrine/jobs/job.py index df0615408..397a99971 100644 --- a/src/citrine/jobs/job.py +++ b/src/citrine/jobs/job.py @@ -1,7 +1,9 @@ +from gemd.enumeration.base_enumeration import BaseEnumeration from logging import getLogger from time import time, sleep from typing import Union from uuid import UUID +from warnings import warn from citrine._rest.resource import Resource from citrine._serialization.properties import Set as PropertySet, String, Object @@ -23,6 +25,16 @@ class JobSubmissionResponse(Resource['JobSubmissionResponse']): """:UUID: job id of the job submission request""" +class JobStatus(BaseEnumeration): + """The valid status codes for a job.""" + + SUBMITTED = "Submitted" + PENDING = "Pending" + RUNNING = "Running" + SUCCESS = "Success" + FAILURE = "Failure" + + class TaskNode(Resource['TaskNode']): """Individual task status. @@ -33,14 +45,29 @@ class TaskNode(Resource['TaskNode']): """:str: unique identification number for the job task""" task_type = properties.String("task_type") """:str: the type of task running""" - status = properties.String("status") - """:str: The last reported status of this particular task. - One of "Submitted", "Pending", "Running", "Success", or "Failure".""" + _status = properties.String("status") dependencies = PropertySet(String(), "dependencies") """:Set[str]: all the tasks that this task is dependent on""" failure_reason = properties.Optional(String(), "failure_reason") """:str: if a task has failed, the failure reason will be in this parameter""" + @property + def status(self) -> Union[JobStatus, str]: + """The last reported status of this particular task.""" + if resolved := JobStatus.from_str(self._status, exception=False): + return resolved + else: + return self._status + + @status.setter + def status(self, value: Union[JobStatus, str]) -> None: + if JobStatus.from_str(value, exception=False) is None: + warn( + f"{value} is not a recognized JobStatus; this will become an error as of v4.0.0.", + DeprecationWarning + ) + self._status = value + class JobStatusResponse(Resource['JobStatusResponse']): """A response to a job status check. @@ -50,13 +77,37 @@ class JobStatusResponse(Resource['JobStatusResponse']): job_type = properties.String("job_type") """:str: the type of job for this status report""" - status = properties.String("status") + _status = properties.String("status") """:str: The status of the job. One of "Running", "Success", or "Failure".""" tasks = properties.List(Object(TaskNode), "tasks") """:List[TaskNode]: all of the constituent task required to complete this job""" output = properties.Optional(properties.Mapping(String, String), 'output') """:Optional[dict[str, str]]: job output properties and results""" + @property + def status(self) -> Union[JobStatus, str]: + """The last reported status of this particular task.""" + if resolved := JobStatus.from_str(self._status, exception=False): + return resolved + else: + return self._status + + @status.setter + def status(self, value: Union[JobStatus, str]) -> None: + if resolved := JobStatus.from_str(value, exception=False): + if resolved not in [JobStatus.RUNNING, JobStatus.SUCCESS, JobStatus.FAILURE]: + warn( + f"{value} is not a valid JobStatus for a JobStatusResponse; " + f"this will become an error as of v4.0.0.", + DeprecationWarning + ) + else: + warn( + f"{value} is not a recognized JobStatus; this will become an error as of v4.0.0.", + DeprecationWarning + ) + self._status = value + def _poll_for_job_completion(session: Session, job: Union[JobSubmissionResponse, UUID, str], @@ -102,7 +153,7 @@ def _poll_for_job_completion(session: Session, while True: response = session.get_resource(path=path, params=params) status: JobStatusResponse = JobStatusResponse.build(response) - if status.status in ['Success', 'Failure']: + if status.status in [JobStatus.SUCCESS, JobStatus.FAILURE]: break elif time() - start_time < timeout: logger.info( @@ -115,12 +166,12 @@ def _poll_for_job_completion(session: Session, f'Note job on server is unaffected by this timeout.') logger.debug('Last status: {}'.format(status.dump())) raise PollingTimeoutError('Job {} timed out.'.format(job_id)) - if status.status == 'Failure': + if status.status == JobStatus.FAILURE: logger.debug(f'Job terminated with Failure status: {status.dump()}') if raise_errors: failure_reasons = [] for task in status.tasks: - if task.status == 'Failure': + if task.status == JobStatus.FAILURE: logger.error(f'Task {task.id} failed with reason "{task.failure_reason}"') failure_reasons.append(task.failure_reason) raise JobFailureError( diff --git a/src/citrine/resources/ingestion.py b/src/citrine/resources/ingestion.py index 728ba6503..916426db1 100644 --- a/src/citrine/resources/ingestion.py +++ b/src/citrine/resources/ingestion.py @@ -196,6 +196,9 @@ class Ingestion(Resource['Ingestion']): raise_errors = properties.Optional(properties.Boolean(), 'raise_errors', default=True) @property + @deprecated(deprecated_in='3.11.0', removed_in='4.0.0', + details="The project_id attribute is deprecated since " + "dataset access is now controlled through teams.") def project_id(self) -> Optional[UUID]: """[DEPRECATED] The project ID associated with this ingest.""" return self._project_id @@ -300,7 +303,7 @@ def build_objects_async(self, if not build_table: project_id = None elif project is None: - if self.project_id is None: + if self._project_id is None: raise ValueError("Building a table requires a target project.") else: warn( @@ -308,7 +311,7 @@ def build_objects_async(self, "and will be removed in v4. Please pass a project explicitly.", DeprecationWarning ) - project_id = self.project_id + project_id = self._project_id elif isinstance(project, Project): project_id = project.uid elif isinstance(project, UUID): @@ -365,18 +368,26 @@ def poll_for_job_completion(self, if polling_delay is not None: kwargs["polling_delay"] = polling_delay - _poll_for_job_completion( + build_job_status = _poll_for_job_completion( session=self.session, team_id=self.team_id, job=job, raise_errors=False, # JobFailureError doesn't contain the error **kwargs ) + if build_job_status.output is not None and "table_build_job_id" in build_job_status.output: + _poll_for_job_completion( + session=self.session, + team_id=self.team_id, + job=build_job_status.output["table_build_job_id"], + raise_errors=False, # JobFailureError doesn't contain the error + **kwargs + ) return self.status() def status(self) -> IngestionStatus: """ - [ALPHA] Retrieve the status of the ingestion from platform. + [ALPHA] Retrieve the status of the ingestion from platform. Returns ---------- @@ -438,7 +449,7 @@ def poll_for_job_completion(self, def status(self) -> IngestionStatus: """ - [ALPHA] Retrieve the status of the ingestion from platform. + [ALPHA] Retrieve the status of the ingestion from platform. Returns ---------- diff --git a/tests/jobs/test_deprecations.py b/tests/jobs/test_deprecations.py new file mode 100644 index 000000000..af0b870ee --- /dev/null +++ b/tests/jobs/test_deprecations.py @@ -0,0 +1,39 @@ +from citrine.jobs.job import JobStatus, JobStatusResponse, TaskNode +import pytest +import warnings + +from tests.utils.factories import TaskNodeDataFactory, JobStatusResponseDataFactory + +def test_status_response_status(): + status_response = JobStatusResponse.build(JobStatusResponseDataFactory(failure=True)) + assert status_response.status == JobStatus.FAILURE + + with pytest.deprecated_call(): + status_response.status = 'Failed' + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert not isinstance(status_response.status, JobStatus) + + with pytest.deprecated_call(): + status_response.status = JobStatus.PENDING + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert status_response.status == JobStatus.PENDING + + with warnings.catch_warnings(): + warnings.simplefilter("error") + status_response.status = JobStatus.SUCCESS + assert status_response.status == JobStatus.SUCCESS + +def test_task_node_status(): + status_response = TaskNode.build(TaskNodeDataFactory(failure=True)) + assert status_response.status == JobStatus.FAILURE + + with pytest.deprecated_call(): + status_response.status = 'Failed' + assert not isinstance(status_response.status, JobStatus) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + status_response.status = JobStatus.SUCCESS + assert status_response.status == JobStatus.SUCCESS diff --git a/tests/jobs/test_waiting.py b/tests/jobs/test_waiting.py index 0571915f4..62ce79f0d 100644 --- a/tests/jobs/test_waiting.py +++ b/tests/jobs/test_waiting.py @@ -7,8 +7,6 @@ import time from citrine.informatics.executions.design_execution import DesignExecution -from citrine.informatics.executions.predictor_evaluation_execution import ( - PredictorEvaluationExecution) from citrine.jobs.waiting import ( wait_for_asynchronous_object, wait_while_executing, @@ -53,7 +51,7 @@ def test_wait_while_validating_timeout(sleep_mock, time_mock): module.in_progress.return_value = True collection.get.return_value = module - with pytest.raises(ConditionTimeoutError) as exceptio: + with pytest.raises(ConditionTimeoutError): wait_while_validating(collection=collection, module=module, timeout=1.0) @mock.patch('time.sleep', return_value=None) diff --git a/tests/resources/test_file_link.py b/tests/resources/test_file_link.py index c08b78523..0a768a1cc 100644 --- a/tests/resources/test_file_link.py +++ b/tests/resources/test_file_link.py @@ -13,7 +13,10 @@ from citrine.resources.ingestion import Ingestion, IngestionCollection from citrine.exceptions import NotFound -from tests.utils.factories import FileLinkDataFactory, _UploaderFactory +from tests.utils.factories import ( + FileLinkDataFactory, _UploaderFactory, JobStatusResponseDataFactory, + IngestionStatusResponseDataFactory, IngestFilesResponseDataFactory, JobSubmissionResponseDataFactory +) from tests.utils.session import FakeSession, FakeS3Client, FakeCall, FakeRequestResponseApiError @@ -536,31 +539,15 @@ def test_ingest(collection: FileCollection, session): good_file2 = collection.build({"filename": "also.csv", "id": str(uuid4()), "version": str(uuid4())}) bad_file = FileLink(filename="bad.csv", url="http://files.com/input.csv") - ingest_create_resp = { - "team_id": str(uuid4()), - "dataset_id": str(uuid4()), - "ingestion_id": str(uuid4()) - } - job_id_resp = { - 'job_id': str(uuid4()) - } - job_status_resp = { - 'job_id': job_id_resp['job_id'], - 'job_type': 'create-gemd-objects', - 'status': 'Success', - 'tasks': [{'id': f'create-gemd-objects-{uuid4()}', - 'task_type': 'create-gemd-objects-task', - 'status': 'Success', - 'dependencies': [], - 'failure_reason': None}], - 'output': {} - } - ingest_status_resp = { - "ingestion_id": ingest_create_resp["ingestion_id"], - "status": "ingestion_created", - "errors": [], - } - session.set_responses(ingest_create_resp, job_id_resp, job_status_resp, ingest_status_resp) + ingest_files_resp = IngestFilesResponseDataFactory() + job_id_resp = JobSubmissionResponseDataFactory() + job_status_resp = JobStatusResponseDataFactory( + job_id=job_id_resp['job_id'], + job_type='create-gemd-objects', + ) + ingest_status_resp = IngestionStatusResponseDataFactory() + + session.set_responses(ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp) collection.ingest([good_file1, good_file2]) with pytest.raises(ValueError, match=bad_file.url): @@ -572,7 +559,7 @@ def test_ingest(collection: FileCollection, session): with pytest.raises(ValueError): collection.ingest([good_file1], build_table=True) - session.set_responses(ingest_create_resp, job_id_resp, job_status_resp, ingest_status_resp) + session.set_responses(ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp) coll_with_project_id = FileCollection(team_id=uuid4(), dataset_id=uuid4(), session=session) coll_with_project_id.project_id = uuid4() with pytest.deprecated_call(): diff --git a/tests/resources/test_gemd_resource.py b/tests/resources/test_gemd_resource.py index ec372f86f..780aa30c9 100644 --- a/tests/resources/test_gemd_resource.py +++ b/tests/resources/test_gemd_resource.py @@ -45,7 +45,7 @@ from citrine._utils.functions import format_escaped_url from tests.utils.factories import MaterialRunDataFactory, MaterialSpecDataFactory -from tests.utils.factories import JobSubmissionResponseFactory +from tests.utils.factories import JobSubmissionResponseDataFactory from tests.utils.session import FakeSession, FakeCall @@ -409,7 +409,7 @@ def test_async_update(gemd_collection, session): 'output': {} } - session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp) + session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) # This returns None on successful update with wait. gemd_collection.async_update(obj, wait_for_response=True) @@ -423,7 +423,7 @@ def test_async_update_and_no_dataset_id(gemd_collection, session): uids={'id': str(uuid4())} ) - session.set_response(JobSubmissionResponseFactory()) + session.set_response(JobSubmissionResponseDataFactory()) gemd_collection.dataset_id = None with pytest.raises(RuntimeError): @@ -444,7 +444,7 @@ def test_async_update_timeout(gemd_collection, session): 'output': {} } - session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp) + session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) with pytest.raises(PollingTimeoutError): gemd_collection.async_update(obj, wait_for_response=True, @@ -465,7 +465,7 @@ def test_async_update_and_wait(gemd_collection, session): 'output': {} } - session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp) + session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) # This returns None on successful update with wait. gemd_collection.async_update(obj, wait_for_response=True) @@ -485,7 +485,7 @@ def test_async_update_and_wait_failure(gemd_collection, session): 'output': {} } - session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp) + session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) with pytest.raises(JobFailureError): gemd_collection.async_update(obj, wait_for_response=True) @@ -499,7 +499,7 @@ def test_async_update_with_no_wait(gemd_collection, session): uids={'id': str(uuid4())} ) - session.set_response(JobSubmissionResponseFactory()) + session.set_response(JobSubmissionResponseDataFactory()) job_id = gemd_collection.async_update(obj, wait_for_response=False) assert job_id is not None diff --git a/tests/resources/test_ingestion.py b/tests/resources/test_ingestion.py index e741e4bd6..5b3f791f0 100644 --- a/tests/resources/test_ingestion.py +++ b/tests/resources/test_ingestion.py @@ -6,12 +6,17 @@ from citrine.resources.api_error import ValidationError from citrine.resources.dataset import Dataset from citrine.resources.file_link import FileLink -from citrine.resources.ingestion import Ingestion, IngestionCollection, IngestionStatus, IngestionStatusType, \ - IngestionException, IngestionErrorTrace, IngestionErrorType, IngestionErrorFamily, IngestionErrorLevel +from citrine.resources.ingestion import ( + Ingestion, IngestionCollection, IngestionStatus, IngestionStatusType, IngestionException, + IngestionErrorTrace, IngestionErrorType, IngestionErrorFamily, IngestionErrorLevel +) from citrine.jobs.job import JobSubmissionResponse, JobStatusResponse, JobFailureError from citrine.resources.project import Project -from tests.utils.factories import DatasetFactory +from tests.utils.factories import ( + DatasetFactory, IngestionStatusResponseDataFactory, JobSubmissionResponseDataFactory, + JobStatusResponseDataFactory +) from tests.utils.session import FakeCall, FakeSession, FakeRequestResponseApiError @@ -136,6 +141,8 @@ def _mock_poll_for_job_completion( outer_polling_delay = polling_delay outer_raise_errors = raise_errors + return JobStatusResponse.build(JobStatusResponseDataFactory()) + def _mock_status(self) -> IngestionStatus: return status @@ -156,13 +163,7 @@ def _mock_status(self) -> IngestionStatus: def test_processing_exceptions(session, ingest, monkeypatch): def _mock_poll_for_job_completion(**_): - response = { - "job_type": "Ingestion!!!!! :D", - "status": "Success", - "tasks": [], - "output": dict() - } - return JobStatusResponse.build(response) + return JobStatusResponse.build(JobStatusResponseDataFactory()) # This is mocked equivalently for all tests monkeypatch.setattr("citrine.resources.ingestion._poll_for_job_completion", _mock_poll_for_job_completion) @@ -269,15 +270,15 @@ def test_ingestion_with_table_build(session: FakeSession, deprecated_dataset: Dataset, file_link: FileLink): # build_objects_async will always approve, if we get that far - session.set_responses( - {"job_id": str(uuid4())} - ) + session.set_responses(JobSubmissionResponseDataFactory()) with pytest.raises(ValueError): ingest.build_objects_async(build_table=True) with pytest.deprecated_call(): ingest.project_id = uuid4() + with pytest.deprecated_call(): + assert ingest.project_id is not None with pytest.deprecated_call(): ingest.build_objects_async(build_table=True) with pytest.deprecated_call(): @@ -295,6 +296,27 @@ def test_ingestion_with_table_build(session: FakeSession, ingest.build_objects_async(build_table=True, project=str(project_uuid)) assert session.last_call.params["project_id"] == project_uuid + # full build_objects + full_build_job = JobSubmissionResponseDataFactory() + output = { + 'ingestion_id': str(ingest.uid), + 'gemd_table_config_version': '1', + 'table_build_job_id': str(uuid4()), + 'gemd_table_config_id': str(uuid4()) + } + session.set_responses( + full_build_job, + JobStatusResponseDataFactory( + job_id=full_build_job["job_id"], + output=output, + ), + JobStatusResponseDataFactory(), + IngestionStatusResponseDataFactory() + ) + status = ingest.build_objects(build_table=True, project=str(project_uuid)) + assert status.success + + def test_ingestion_flow(session: FakeSession, ingest: Ingestion, collection: IngestionCollection, @@ -342,18 +364,16 @@ def _raise_exception(): ingest.raise_errors = True session.set_responses( - {"job_id": uuid4()}, - {"job_type": "Ingestion!!!!! :D", "status": "Success", "tasks": [], "output": dict()}, - { - "ingestion_id": ingest.uid, - "status": IngestionStatusType.INGESTION_CREATED, - "errors": [{ + JobSubmissionResponseDataFactory(), + JobStatusResponseDataFactory(), + IngestionStatusResponseDataFactory( + errors=[{ "family": IngestionErrorFamily.DATA, "error_type": IngestionErrorType.MISSING_RAW_FOR_INGREDIENT, "level": IngestionErrorLevel.ERROR, "msg": "Missing ingredient: \"myristic (14:0)\" (Note ingredient IDs are case sensitive)" }] - } + ), ) with pytest.raises(IngestionException, match="Missing ingredient"): ingest.build_objects() diff --git a/tests/utils/factories.py b/tests/utils/factories.py index e19aa69c1..04da2e27a 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -15,8 +15,10 @@ from citrine.gemd_queries.filter import * from citrine.informatics.scores import LIScore from citrine.informatics.workflows import DesignWorkflow +from citrine.jobs.job import JobStatus from citrine.resources.dataset import Dataset from citrine.resources.file_link import _Uploader +from citrine.resources.ingestion import IngestionStatusType from citrine.resources.material_run import MaterialRun from citrine.resources.material_spec import MaterialSpec from citrine.resources.material_template import MaterialTemplate @@ -525,10 +527,57 @@ class Params: status_description = "" # TODO: Should be None, but property not defined as Optional -class JobSubmissionResponseFactory(factory.DictFactory): +class IngestFilesResponseDataFactory(factory.DictFactory): + team_id = factory.Faker('uuid4') + dataset_id = factory.Faker('uuid4') + ingestion_id = factory.Faker('uuid4') + + +class IngestionStatusResponseDataFactory(factory.DictFactory): + ingestion_id = factory.Faker('uuid4') + status = IngestionStatusType.INGESTION_CREATED + errors = factory.List([]) + + +class JobSubmissionResponseDataFactory(factory.DictFactory): job_id = factory.Faker('uuid4') +class TaskNodeDataFactory(factory.DictFactory): + class Params: + failure = False + + id = factory.Faker('uuid4') + task_type = factory.Faker('word') + status = factory.Maybe( + "failure", + yes_declaration=JobStatus.FAILURE, + no_declaration=JobStatus.SUCCESS + ) + dependencies = factory.List([]) + failure_reason = factory.Maybe( + "failure", + yes_declaration=factory.Faker('sentence'), + no_declaration=None + ) + + +class JobStatusResponseDataFactory(factory.DictFactory): + class Params: + failure = False + + job_type = factory.Faker('word') + status = factory.Maybe( + "failure", + yes_declaration=JobStatus.FAILURE, + no_declaration=JobStatus.SUCCESS + ) + tasks = factory.List([ + factory.RelatedFactory(TaskNodeDataFactory, failure=factory.SelfAttribute('...failure')) + ]) + output = factory.Dict({}) + + class DatasetDataFactory(factory.DictFactory): class Params: times = factory.List([factory.Faker("unix_milliseconds") for i in range(3)])