From 90ab19690024575df26f40a6b391564f257da7b9 Mon Sep 17 00:00:00 2001 From: Ken Kroenlein Date: Tue, 12 Nov 2024 14:30:12 -0700 Subject: [PATCH] PR suggestions --- src/citrine/_session.py | 2 + src/citrine/resources/project.py | 67 +++++++++++++++++++------------- tests/resources/test_project.py | 41 ++----------------- tests/utils/session.py | 12 +++--- 4 files changed, 52 insertions(+), 70 deletions(-) diff --git a/src/citrine/_session.py b/src/citrine/_session.py index e7033cf35..c403a574b 100644 --- a/src/citrine/_session.py +++ b/src/citrine/_session.py @@ -151,6 +151,8 @@ def checked_request(self, method: str, path: str, logger.debug('\tmethod: {}'.format(method)) logger.debug('\tpath: {}'.format(path)) logger.debug('\tversion: {}'.format(version)) + for k, v in kwargs.items(): + logger.debug(f'\t{k}: {v}') if self._is_access_token_expired(): self._refresh_access_token() diff --git a/src/citrine/resources/project.py b/src/citrine/resources/project.py index b9d75da37..08108aa33 100644 --- a/src/citrine/resources/project.py +++ b/src/citrine/resources/project.py @@ -1,9 +1,9 @@ """Resources that represent both individual and collections of projects.""" -from functools import partial +from deprecation import deprecated from typing import Optional, Dict, List, Union, Iterable, Tuple, Iterator from uuid import UUID +from warnings import warn -from deprecation import deprecated from gemd.entity.base_entity import BaseEntity from gemd.entity.link_by_uid import LinkByUID @@ -46,7 +46,6 @@ from citrine.resources.project_member import ProjectMember from citrine.resources.response import Response from citrine.resources.table_config import TableConfigCollection -from warnings import warn class Project(Resource['Project']): @@ -519,14 +518,16 @@ class ProjectCollection(Collection[Project]): """ - _path_template = '/projects' + @property + def _path_template(self): + if self.team_id is None: + return '/projects' + else: + return '/teams/{team_id}/projects' _individual_key = 'project' _collection_key = 'projects' _resource = Project - - @property - def _api_version(self): - return 'v3' + _api_version = 'v3' def __init__(self, session: Session, *, team_id: Optional[UUID] = None): self.session = session @@ -553,6 +554,22 @@ def build(self, data) -> Project: project.team_id = self.team_id return project + def get(self, uid: Union[UUID, str]) -> Project: + """ + Get a particular project. + + Parameters + ---------- + uid: UUID or str + The uid of the project to get. + + """ + # Only the team-agnostic project get is implemented + if self.team_id is None: + return super().get(uid) + else: + return ProjectCollection(session=self.session).get(uid) + def register(self, name: str, *, description: Optional[str] = None) -> Project: """ Create and upload new project. @@ -564,15 +581,19 @@ def register(self, name: str, *, description: Optional[str] = None) -> Project: description: str Long-form description of the project to be created. + Return + ------- + Project + The newly registered project. + """ if self.team_id is None: raise NotImplementedError("Cannot register a project without a team ID. " "Use team.projects.register.") - path = format_escaped_url('teams/{team_id}/projects', team_id=self.team_id) project = Project(name, description=description) try: - data = self.session.post_resource(path, project.dump(), version=self._api_version) + data = self.session.post_resource(self._get_path(), project.dump()) data = data[self._individual_key] return self.build(data) except NonRetryableException as e: @@ -595,15 +616,7 @@ def list(self, *, per_page: int = 1000) -> Iterator[Project]: Projects in this collection. """ - if self.team_id is None: - path = '/projects' - else: - path = format_escaped_url('/teams/{team_id}/projects', team_id=self.team_id) - - fetcher = partial(self._fetch_page, path=path) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + return super().list(per_page=per_page) def search_all(self, search_params: Optional[Dict]) -> Iterable[Dict]: """ @@ -647,19 +660,13 @@ def search_all(self, search_params: Optional[Dict]) -> Iterable[Dict]: """ collections = [] - if self.team_id is None: - path = "/projects/search" - else: - path = format_escaped_url("/teams/{team_id}/projects/search", team_id=self.team_id) - query_params = {'userId': ""} json = {} if search_params is None else {'search_params': search_params} - data = self.session.post_resource(path, + data = self.session.post_resource(self._get_path(action="search"), params=query_params, - json=json, - version=self._api_version) + json=json) if self._collection_key is not None: collections = data[self._collection_key] @@ -734,7 +741,11 @@ def delete(self, uid: Union[UUID, str]) -> Response: If the project is not empty, then the Response will contain a list of all of the project's resources. These must be deleted before the project can be deleted. """ - return super().delete(uid) + # Only the team-agnostic project get is implemented + if self.team_id is None: + return super().delete(uid) + else: + return ProjectCollection(session=self.session).delete(uid) def update(self, model: Project) -> Project: """Projects cannot be updated.""" diff --git a/tests/resources/test_project.py b/tests/resources/test_project.py index 264dee288..28a869860 100644 --- a/tests/resources/test_project.py +++ b/tests/resources/test_project.py @@ -353,36 +353,6 @@ def test_failed_register_no_team(session): project_collection.register("Project") -def test_project_registration(collection: ProjectCollection, session): - # Given - create_time = parse('2019-09-10T00:00:00+00:00') - project_data = ProjectDataFactory( - name='testing', - description='A sample project', - created_at=int(create_time.timestamp() * 1000) # The lib expects ms since epoch, which is really odd - ) - session.set_response({'project': project_data}) - - # When - with pytest.warns(DeprecationWarning): - created_project = collection.register('testing') - - # Then - assert 1 == session.num_calls - expected_call = FakeCall( - method='POST', - path='/projects', - json={ - 'name': 'testing' - } - ) - assert expected_call == session.last_call - - assert 'A sample project' == created_project.description - assert 'CREATED' == created_project.status - assert create_time == created_project.created_at - - def test_project_registration(collection: ProjectCollection, session): # Given create_time = parse('2019-09-10T00:00:00+00:00') @@ -454,7 +424,7 @@ def test_list_no_team(session): projects = list(project_collection.list()) assert 1 == session.num_calls - expected_call = FakeCall(method='GET', path=f'/projects', params={'per_page': 1000, 'page': 1}) + expected_call = FakeCall(method='GET', path='/projects', params={'per_page': 1000, 'page': 1}) assert expected_call == session.last_call assert 5 == len(projects) @@ -598,7 +568,7 @@ def test_delete_project(collection, session): # Then assert 1 == session.num_calls - expected_call = FakeCall(method='DELETE', path='/projects/{}'.format(uid)) + expected_call = FakeCall(method='DELETE', path=f'/projects/{uid}') assert expected_call == session.last_call @@ -628,11 +598,8 @@ def test_list_members(project, session): # Then assert 2 == session.num_calls - expect_call_1 = FakeCall( - method='GET', - path='/teams/{}'.format(team_data['id']), - ) - expect_call_2 = FakeCall(method='GET', path='/teams/{}/users'.format(project.team_id)) + expect_call_1 = FakeCall(method='GET', path=f'/teams/{team_data["id"]}') + expect_call_2 = FakeCall(method='GET', path=f'/teams/{project.team_id}/users') assert expect_call_1 == session.calls[0] assert expect_call_2 == session.calls[1] assert isinstance(members[0], TeamMember) diff --git a/tests/utils/session.py b/tests/utils/session.py index c74d7b278..0a6a7b97d 100644 --- a/tests/utils/session.py +++ b/tests/utils/session.py @@ -36,11 +36,13 @@ def __eq__(self, other) -> bool: if not isinstance(other, FakeCall): return NotImplemented - return self.method == other.method and \ - self.path == other.path and \ - self.json == other.json and \ - self.params == other.params and \ - (not self.version or not other.version or self.version == other.version) # Allows users to check the URL version without forcing everyone to. + return ( + self.method == other.method and + self.path.lstrip('/') == other.path.lstrip('/') and # Leading slashes don't affect results + self.json == other.json and + self.params == other.params and + (not self.version or not other.version or self.version == other.version) # Allows users to check the URL version without forcing everyone to. + ) class FakeSession(Session):