diff --git a/google/auth/jwt.py b/google/auth/jwt.py index 412f122e3..b1eb5fb91 100644 --- a/google/auth/jwt.py +++ b/google/auth/jwt.py @@ -46,13 +46,17 @@ import datetime import json +import cachetools +from six.moves import urllib + from google.auth import _helpers from google.auth import _service_account_info from google.auth import crypt +from google.auth import exceptions import google.auth.credentials - _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds +_DEFAULT_MAX_CACHE_SIZE = 10 def encode(signer, payload, header=None, key_id=None): @@ -316,10 +320,10 @@ def __init__(self, signer, issuer, subject, audience, self._audience = audience self._token_lifetime = token_lifetime - if additional_claims is not None: - self._additional_claims = additional_claims - else: - self._additional_claims = {} + if additional_claims is None: + additional_claims = {} + + self._additional_claims = additional_claims @classmethod def _from_signer_and_info(cls, signer, info, **kwargs): @@ -343,8 +347,7 @@ def _from_signer_and_info(cls, signer, info, **kwargs): @classmethod def from_service_account_info(cls, info, **kwargs): - """Creates a Credentials instance from a dictionary containing service - account info in Google format. + """Creates an Credentials instance from a dictionary. Args: info (Mapping[str, str]): The service account info in Google @@ -487,3 +490,266 @@ def signer_email(self): @_helpers.copy_docstring(google.auth.credentials.Signing) def signer(self): return self._signer + + +class OnDemandCredentials( + google.auth.credentials.Signing, + google.auth.credentials.Credentials): + """On-demand JWT credentials. + + Like :class:`Credentials`, this class uses a JWT as the bearer token for + authentication. However, this class does not require the audience at + construction time. Instead, it will generate a new token on-demand for + each request using the request URI as the audience. It caches tokens + so that multiple requests to the same URI do not incur the overhead + of generating a new token every time. + + This behavior is especially useful for `gRPC`_ clients. A gRPC service may + have multiple audience and gRPC clients may not know all of the audiences + required for accessing a particular service. With these credentials, + no knowledge of the audiences is required ahead of time. + + .. _grpc: http://www.grpc.io/ + """ + + def __init__(self, signer, issuer, subject, + additional_claims=None, + token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, + max_cache_size=_DEFAULT_MAX_CACHE_SIZE): + """ + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + issuer (str): The `iss` claim. + subject (str): The `sub` claim. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. + token_lifetime (int): The amount of time in seconds for + which the token is valid. Defaults to 1 hour. + max_cache_size (int): The maximum number of JWT tokens to keep in + cache. Tokens are cached using :class:`cachetools.LRUCache`. + """ + super(OnDemandCredentials, self).__init__() + self._signer = signer + self._issuer = issuer + self._subject = subject + self._token_lifetime = token_lifetime + + if additional_claims is None: + additional_claims = {} + + self._additional_claims = additional_claims + self._cache = cachetools.LRUCache(maxsize=max_cache_size) + + @classmethod + def _from_signer_and_info(cls, signer, info, **kwargs): + """Creates an OnDemandCredentials instance from a signer and service + account info. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign JWTs. + info (Mapping[str, str]): The service account info. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.OnDemandCredentials: The constructed credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + kwargs.setdefault('subject', info['client_email']) + kwargs.setdefault('issuer', info['client_email']) + return cls(signer, **kwargs) + + @classmethod + def from_service_account_info(cls, info, **kwargs): + """Creates an OnDemandCredentials instance from a dictionary. + + Args: + info (Mapping[str, str]): The service account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.OnDemandCredentials: The constructed credentials. + + Raises: + ValueError: If the info is not in the expected format. + """ + signer = _service_account_info.from_dict( + info, require=['client_email']) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_service_account_file(cls, filename, **kwargs): + """Creates an OnDemandCredentials instance from a service account .json + file in Google format. + + Args: + filename (str): The path to the service account .json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.OnDemandCredentials: The constructed credentials. + """ + info, signer = _service_account_info.from_filename( + filename, require=['client_email']) + return cls._from_signer_and_info(signer, info, **kwargs) + + @classmethod + def from_signing_credentials(cls, credentials, **kwargs): + """Creates a new :class:`google.auth.jwt.OnDemandCredentials` instance + from an existing :class:`google.auth.credentials.Signing` instance. + + The new instance will use the same signer as the existing instance and + will use the existing instance's signer email as the issuer and + subject by default. + + Example:: + + svc_creds = service_account.Credentials.from_service_account_file( + 'service_account.json') + jwt_creds = jwt.OnDemandCredentials.from_signing_credentials( + svc_creds) + + Args: + credentials (google.auth.credentials.Signing): The credentials to + use to construct the new credentials. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.jwt.Credentials: A new Credentials instance. + """ + kwargs.setdefault('issuer', credentials.signer_email) + kwargs.setdefault('subject', credentials.signer_email) + return cls(credentials.signer, **kwargs) + + def with_claims(self, issuer=None, subject=None, additional_claims=None): + """Returns a copy of these credentials with modified claims. + + Args: + issuer (str): The `iss` claim. If unspecified the current issuer + claim will be used. + subject (str): The `sub` claim. If unspecified the current subject + claim will be used. + additional_claims (Mapping[str, str]): Any additional claims for + the JWT payload. This will be merged with the current + additional claims. + + Returns: + google.auth.jwt.OnDemandCredentials: A new credentials instance. + """ + new_additional_claims = copy.deepcopy(self._additional_claims) + new_additional_claims.update(additional_claims or {}) + + return OnDemandCredentials( + self._signer, + issuer=issuer if issuer is not None else self._issuer, + subject=subject if subject is not None else self._subject, + additional_claims=new_additional_claims, + max_cache_size=self._cache.maxsize) + + @property + def valid(self): + """Checks the validity of the credentials. + + These credentials are always valid because it generates tokens on + demand. + """ + return True + + def _make_jwt_for_audience(self, audience): + """Make a new JWT for the given audience. + + Args: + audience (str): The intended audience. + + Returns: + Tuple[bytes, datetime]: The encoded JWT and the expiration. + """ + now = _helpers.utcnow() + lifetime = datetime.timedelta(seconds=self._token_lifetime) + expiry = now + lifetime + + payload = { + 'iss': self._issuer, + 'sub': self._subject, + 'iat': _helpers.datetime_to_secs(now), + 'exp': _helpers.datetime_to_secs(expiry), + 'aud': audience, + } + + payload.update(self._additional_claims) + + jwt = encode(self._signer, payload) + + return jwt, expiry + + def _get_jwt_for_audience(self, audience): + """Get a JWT For a given audience. + + If there is already an existing, non-expired token in the cache for + the audience, that token is used. Otherwise, a new token will be + created. + + Args: + audience (str): The intended audience. + + Returns: + bytes: The encoded JWT. + """ + token, expiry = self._cache.get(audience, (None, None)) + + if token is None or expiry < _helpers.utcnow(): + token, expiry = self._make_jwt_for_audience(audience) + self._cache[audience] = token, expiry + + return token + + def refresh(self, request): + """Raises an exception, these credentials can not be directly + refreshed. + + Args: + request (Any): Unused. + + Raises: + google.auth.RefreshError + """ + # pylint: disable=unused-argument + # (pylint doesn't correctly recognize overridden methods.) + raise exceptions.RefreshError( + 'OnDemandCredentials can not be directly refreshed.') + + def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Args: + request (Any): Unused. JWT credentials do not need to make an + HTTP request to refresh. + method (str): The request's HTTP method. + url (str): The request's URI. This is used as the audience claim + when generating the JWT. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (pylint doesn't correctly recognize overridden methods.) + parts = urllib.parse.urlsplit(url) + # Strip query string and fragment + audience = urllib.parse.urlunsplit( + (parts.scheme, parts.netloc, parts.path, None, None)) + token = self._get_jwt_for_audience(audience) + self.apply(headers, token=token) + + @_helpers.copy_docstring(google.auth.credentials.Signing) + def sign_bytes(self, message): + return self._signer.sign(message) + + @property + @_helpers.copy_docstring(google.auth.credentials.Signing) + def signer_email(self): + return self._issuer + + @property + @_helpers.copy_docstring(google.auth.credentials.Signing) + def signer(self): + return self._signer diff --git a/setup.py b/setup.py index aaa13de4a..bad634a6a 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ 'pyasn1-modules>=0.0.5', 'rsa>=3.1.4', 'six>=1.9.0', + 'cachetools>=2.0.0', ) diff --git a/system_tests/test_grpc.py b/system_tests/test_grpc.py index 4bf1c5ba5..365bc91d3 100644 --- a/system_tests/test_grpc.py +++ b/system_tests/test_grpc.py @@ -39,7 +39,7 @@ def test_grpc_request_with_regular_credentials(http_request): list(list_topics_iter) -def test_grpc_request_with_jwt_credentials(http_request): +def test_grpc_request_with_jwt_credentials(): credentials, project_id = google.auth.default() audience = 'https://{}/google.pubsub.v1.Publisher'.format( publisher_client.PublisherClient.SERVICE_ADDRESS) @@ -49,7 +49,27 @@ def test_grpc_request_with_jwt_credentials(http_request): channel = google.auth.transport.grpc.secure_authorized_channel( credentials, - http_request, + None, + publisher_client.PublisherClient.SERVICE_ADDRESS) + + # Create a pub/sub client. + client = publisher_client.PublisherClient(channel=channel) + + # list the topics and drain the iterator to test that an authorized API + # call works. + list_topics_iter = client.list_topics( + project='projects/{}'.format(project_id)) + list(list_topics_iter) + + +def test_grpc_request_with_on_demand_jwt_credentials(): + credentials, project_id = google.auth.default() + credentials = google.auth.jwt.OnDemandCredentials.from_signing_credentials( + credentials) + + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, + None, publisher_client.PublisherClient.SERVICE_ADDRESS) # Create a pub/sub client. diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 59769de2e..22c5bc538 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -22,6 +22,7 @@ from google.auth import _helpers from google.auth import crypt +from google.auth import exceptions from google.auth import jwt @@ -196,7 +197,7 @@ def test_roundtrip_explicit_key_id(token_factory): assert payload['user'] == 'billy bob' -class TestCredentials: +class TestCredentials(object): SERVICE_ACCOUNT_EMAIL = 'service-account@example.com' SUBJECT = 'subject' AUDIENCE = 'audience' @@ -343,3 +344,135 @@ def test_before_request_refreshes(self): self.credentials.before_request( None, 'GET', 'http://example.com?a=1#3', {}) assert self.credentials.valid + + +class TestOnDemandCredentials(object): + SERVICE_ACCOUNT_EMAIL = 'service-account@example.com' + SUBJECT = 'subject' + ADDITIONAL_CLAIMS = {'meta': 'data'} + credentials = None + + @pytest.fixture(autouse=True) + def credentials_fixture(self, signer): + self.credentials = jwt.OnDemandCredentials( + signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL, + max_cache_size=2) + + def test_from_service_account_info(self): + with open(SERVICE_ACCOUNT_JSON_FILE, 'r') as fh: + info = json.load(fh) + + credentials = jwt.OnDemandCredentials.from_service_account_info(info) + + assert credentials._signer.key_id == info['private_key_id'] + assert credentials._issuer == info['client_email'] + assert credentials._subject == info['client_email'] + + def test_from_service_account_info_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.OnDemandCredentials.from_service_account_info( + info, subject=self.SUBJECT, + additional_claims=self.ADDITIONAL_CLAIMS) + + assert credentials._signer.key_id == info['private_key_id'] + assert credentials._issuer == info['client_email'] + assert credentials._subject == self.SUBJECT + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_service_account_file(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.OnDemandCredentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE) + + assert credentials._signer.key_id == info['private_key_id'] + assert credentials._issuer == info['client_email'] + assert credentials._subject == info['client_email'] + + def test_from_service_account_file_args(self): + info = SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt.OnDemandCredentials.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE, subject=self.SUBJECT, + additional_claims=self.ADDITIONAL_CLAIMS) + + assert credentials._signer.key_id == info['private_key_id'] + assert credentials._issuer == info['client_email'] + assert credentials._subject == self.SUBJECT + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_signing_credentials(self): + jwt_from_signing = self.credentials.from_signing_credentials( + self.credentials) + jwt_from_info = jwt.OnDemandCredentials.from_service_account_info( + SERVICE_ACCOUNT_INFO) + + assert isinstance(jwt_from_signing, jwt.OnDemandCredentials) + assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id + assert jwt_from_signing._issuer == jwt_from_info._issuer + assert jwt_from_signing._subject == jwt_from_info._subject + + def test_default_state(self): + # Credentials are *always* valid. + assert self.credentials.valid + # Credentials *never* expire. + assert not self.credentials.expired + + def test_with_claims(self): + new_claims = {'meep': 'moop'} + new_credentials = self.credentials.with_claims( + additional_claims=new_claims) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._additional_claims == new_claims + + def test_sign_bytes(self): + to_sign = b'123' + signature = self.credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES) + + def test_signer(self): + assert isinstance(self.credentials.signer, crypt.RSASigner) + + def test_signer_email(self): + assert (self.credentials.signer_email == + SERVICE_ACCOUNT_INFO['client_email']) + + def _verify_token(self, token): + payload = jwt.decode(token, PUBLIC_CERT_BYTES) + assert payload['iss'] == self.SERVICE_ACCOUNT_EMAIL + return payload + + def test_refresh(self): + with pytest.raises(exceptions.RefreshError): + self.credentials.refresh(None) + + def test_before_request(self): + headers = {} + + self.credentials.before_request( + None, 'GET', 'http://example.com?a=1#3', headers) + + _, token = headers['authorization'].split(' ') + payload = self._verify_token(token) + + assert payload['aud'] == 'http://example.com' + + # Making another request should re-use the same token. + self.credentials.before_request( + None, 'GET', 'http://example.com?b=2', headers) + + _, new_token = headers['authorization'].split(' ') + + assert new_token == token + + def test_expired_token(self): + self.credentials._cache['audience'] = ( + mock.sentinel.token, datetime.datetime.min) + + token = self.credentials._get_jwt_for_audience('audience') + + assert token != mock.sentinel.token