Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.11.5"
__version__ = "3.11.6"
2 changes: 2 additions & 0 deletions src/citrine/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def checked_request(self, method: str, path: str,
logger.debug('\tmethod: {}'.format(method))
logger.debug('\tpath: {}'.format(path))
logger.debug('\tversion: {}'.format(version))
for k, v in kwargs.items():
logger.debug(f'\t{k}: {v}')

if self._is_access_token_expired():
self._refresh_access_token()
Expand Down
66 changes: 38 additions & 28 deletions src/citrine/resources/project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Resources that represent both individual and collections of projects."""
from functools import partial
from deprecation import deprecated
from typing import Optional, Dict, List, Union, Iterable, Tuple, Iterator
from uuid import UUID
from warnings import warn

from deprecation import deprecated
from gemd.entity.base_entity import BaseEntity
from gemd.entity.link_by_uid import LinkByUID

Expand All @@ -12,7 +12,6 @@
from citrine._serialization import properties
from citrine._session import Session
from citrine._utils.functions import format_escaped_url
from citrine.exceptions import NonRetryableException, ModuleRegistrationFailedException
from citrine.resources.api_error import ApiError
from citrine.resources.branch import BranchCollection
from citrine.resources.dataset import DatasetCollection
Expand Down Expand Up @@ -46,7 +45,6 @@
from citrine.resources.project_member import ProjectMember
from citrine.resources.response import Response
from citrine.resources.table_config import TableConfigCollection
from warnings import warn


class Project(Resource['Project']):
Expand Down Expand Up @@ -519,14 +517,16 @@ class ProjectCollection(Collection[Project]):

"""

_path_template = '/projects'
@property
def _path_template(self):
if self.team_id is None:
return '/projects'
else:
return '/teams/{team_id}/projects'
_individual_key = 'project'
_collection_key = 'projects'
_resource = Project

@property
def _api_version(self):
return 'v3'
_api_version = 'v3'

def __init__(self, session: Session, *, team_id: Optional[UUID] = None):
self.session = session
Expand All @@ -553,6 +553,22 @@ def build(self, data) -> Project:
project.team_id = self.team_id
return project

def get(self, uid: Union[UUID, str]) -> Project:
"""
Get a particular project.

Parameters
----------
uid: UUID or str
The uid of the project to get.

"""
# Only the team-agnostic project get is implemented
if self.team_id is None:
return super().get(uid)
else:
return ProjectCollection(session=self.session).get(uid)

def register(self, name: str, *, description: Optional[str] = None) -> Project:
"""
Create and upload new project.
Expand All @@ -564,19 +580,18 @@ def register(self, name: str, *, description: Optional[str] = None) -> Project:
description: str
Long-form description of the project to be created.

Return
-------
Project
The newly registered project.

"""
if self.team_id is None:
raise NotImplementedError("Cannot register a project without a team ID. "
"Use team.projects.register.")

path = format_escaped_url('teams/{team_id}/projects', team_id=self.team_id)
project = Project(name, description=description)
try:
data = self.session.post_resource(path, project.dump(), version=self._api_version)
data = data[self._individual_key]
return self.build(data)
except NonRetryableException as e:
raise ModuleRegistrationFailedException(project.__class__.__name__, e)
return super().register(project)

def list(self, *, per_page: int = 1000) -> Iterator[Project]:
"""
Expand All @@ -595,15 +610,7 @@ def list(self, *, per_page: int = 1000) -> Iterator[Project]:
Projects in this collection.

"""
if self.team_id is None:
path = '/projects'
else:
path = format_escaped_url('/teams/{team_id}/projects', team_id=self.team_id)

fetcher = partial(self._fetch_page, path=path)
return self._paginator.paginate(page_fetcher=fetcher,
collection_builder=self._build_collection_elements,
per_page=per_page)
return super().list(per_page=per_page)

def search_all(self, search_params: Optional[Dict]) -> Iterable[Dict]:
"""
Expand Down Expand Up @@ -647,12 +654,11 @@ def search_all(self, search_params: Optional[Dict]) -> Iterable[Dict]:

"""
collections = []
path = self._get_path(action="search")
query_params = {'userId': ""}

json = {} if search_params is None else {'search_params': search_params}

data = self.session.post_resource(path,
data = self.session.post_resource(self._get_path(action="search"),
params=query_params,
json=json,
version=self._api_version)
Expand Down Expand Up @@ -730,7 +736,11 @@ def delete(self, uid: Union[UUID, str]) -> Response:
If the project is not empty, then the Response will contain a list of all of the project's
resources. These must be deleted before the project can be deleted.
"""
return super().delete(uid)
# Only the team-agnostic project delete is implemented
if self.team_id is None:
return super().delete(uid)
else:
return ProjectCollection(session=self.session).delete(uid)

def update(self, model: Project) -> Project:
"""Projects cannot be updated."""
Expand Down
70 changes: 29 additions & 41 deletions tests/resources/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,36 +353,6 @@ def test_failed_register_no_team(session):
project_collection.register("Project")


def test_project_registration(collection: ProjectCollection, session):
# Given
create_time = parse('2019-09-10T00:00:00+00:00')
project_data = ProjectDataFactory(
name='testing',
description='A sample project',
created_at=int(create_time.timestamp() * 1000) # The lib expects ms since epoch, which is really odd
)
session.set_response({'project': project_data})

# When
with pytest.warns(DeprecationWarning):
created_project = collection.register('testing')

# Then
assert 1 == session.num_calls
expected_call = FakeCall(
method='POST',
path='/projects',
json={
'name': 'testing'
}
)
assert expected_call == session.last_call

assert 'A sample project' == created_project.description
assert 'CREATED' == created_project.status
assert create_time == created_project.created_at


def test_project_registration(collection: ProjectCollection, session):
# Given
create_time = parse('2019-09-10T00:00:00+00:00')
Expand Down Expand Up @@ -454,7 +424,7 @@ def test_list_no_team(session):
projects = list(project_collection.list())

assert 1 == session.num_calls
expected_call = FakeCall(method='GET', path=f'/projects', params={'per_page': 1000, 'page': 1})
expected_call = FakeCall(method='GET', path='/projects', params={'per_page': 1000, 'page': 1})
assert expected_call == session.last_call
assert 5 == len(projects)

Expand All @@ -472,6 +442,27 @@ def test_list_projects_with_page_params(collection, session):
expected_call = FakeCall(method='GET', path=f'/teams/{collection.team_id}/projects', params={'per_page': 10, 'page': 1})
assert expected_call == session.last_call

def test_search_all_no_team(session):
project_collection = ProjectCollection(session=session)
projects_data = ProjectDataFactory.create_batch(2)
project_name_to_match = projects_data[0]['name']

search_params = {
'name': {
'value': project_name_to_match,
'search_method': 'EXACT'}}
expected_response = [p for p in projects_data if p["name"] == project_name_to_match]

project_collection.session.set_response({'projects': expected_response})

# Then
results = list(project_collection.search_all(search_params=search_params))

expected_call = FakeCall(method='POST', path='/projects/search', params={'userId': ''}, json={'search_params': search_params})

assert 1 == project_collection.session.num_calls
assert expected_call == project_collection.session.last_call
assert 1 == len(results)

def test_search_all(collection: ProjectCollection):
# Given
Expand All @@ -490,7 +481,7 @@ def test_search_all(collection: ProjectCollection):
results = list(collection.search_all(search_params=search_params))

expected_call = FakeCall(method='POST',
path='/projects/search',
path=f'/teams/{collection.team_id}/projects/search',
params={'userId': ''},
json={'search_params': {
'name': {
Expand All @@ -513,7 +504,7 @@ def test_search_all_no_search_params(collection: ProjectCollection):
result = list(collection.search_all(search_params=None))

expected_call = FakeCall(method='POST',
path='/projects/search',
path=f'/teams/{collection.team_id}/projects/search',
params={'userId': ''},
json={})

Expand All @@ -539,7 +530,7 @@ def test_search_projects(collection: ProjectCollection):
result = list(collection.search(search_params=search_params))

expected_call = FakeCall(method='POST',
path='/projects/search',
path=f'/teams/{collection.team_id}/projects/search',
params={'userId': ''},
json={'search_params': {
'name': {
Expand All @@ -561,7 +552,7 @@ def test_search_projects_no_search_params(collection: ProjectCollection):
# Then
result = list(collection.search())

expected_call = FakeCall(method='POST', path='/projects/search', params={'userId': ''}, json={})
expected_call = FakeCall(method='POST', path=f'/teams/{collection.team_id}/projects/search', params={'userId': ''}, json={})

assert 1 == collection.session.num_calls
assert expected_call == collection.session.last_call
Expand All @@ -577,7 +568,7 @@ def test_delete_project(collection, session):

# Then
assert 1 == session.num_calls
expected_call = FakeCall(method='DELETE', path='/projects/{}'.format(uid))
expected_call = FakeCall(method='DELETE', path=f'/projects/{uid}')
assert expected_call == session.last_call


Expand Down Expand Up @@ -607,11 +598,8 @@ def test_list_members(project, session):

# Then
assert 2 == session.num_calls
expect_call_1 = FakeCall(
method='GET',
path='/teams/{}'.format(team_data['id']),
)
expect_call_2 = FakeCall(method='GET', path='/teams/{}/users'.format(project.team_id))
expect_call_1 = FakeCall(method='GET', path=f'/teams/{team_data["id"]}')
expect_call_2 = FakeCall(method='GET', path=f'/teams/{project.team_id}/users')
assert expect_call_1 == session.calls[0]
assert expect_call_2 == session.calls[1]
assert isinstance(members[0], TeamMember)
Expand Down
12 changes: 7 additions & 5 deletions tests/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ def __eq__(self, other) -> bool:
if not isinstance(other, FakeCall):
return NotImplemented

return self.method == other.method and \
self.path == other.path and \
self.json == other.json and \
self.params == other.params and \
(not self.version or not other.version or self.version == other.version) # Allows users to check the URL version without forcing everyone to.
return (
self.method == other.method and
self.path.lstrip('/') == other.path.lstrip('/') and # Leading slashes don't affect results
self.json == other.json and
self.params == other.params and
(not self.version or not other.version or self.version == other.version) # Allows users to check the URL version without forcing everyone to.
)


class FakeSession(Session):
Expand Down