From 4c1d6d90bc4326377ce670b74735029db9acde6a Mon Sep 17 00:00:00 2001 From: Yuanjun Huang Date: Thu, 28 Sep 2023 22:02:04 +0000 Subject: [PATCH] [auth] Be able to generate id_token Make auth be able to generate id_token. Some services on Cloud Run will need it (e.g. luci-config v2). Bug: 1487020 Change-Id: Icfe95002f93ee552b99ab2694c7b777e2322484b Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/4899437 Reviewed-by: Yiwei Zhang Commit-Queue: Yuanjun Huang --- auth.py | 73 ++++++++++++++++++++++++++++------------ tests/auth_test.py | 83 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 118 insertions(+), 38 deletions(-) diff --git a/auth.py b/auth.py index 08aeeefca1..285fd6e969 100644 --- a/auth.py +++ b/auth.py @@ -29,14 +29,14 @@ def datetime_now(): return datetime.datetime.utcnow() -# OAuth access token with its expiration time (UTC datetime or None if unknown). -class AccessToken( - collections.namedtuple('AccessToken', [ - 'token', - 'expires_at', - ])): +# OAuth access token or ID token with its expiration time (UTC datetime or None +# if unknown). +class Token(collections.namedtuple('Token', [ + 'token', + 'expires_at', +])): def needs_refresh(self): - """True if this AccessToken should be refreshed.""" + """True if this token should be refreshed.""" if self.expires_at is not None: # Allow 30s of clock skew between client and backend. return datetime_now() + datetime.timedelta( @@ -67,22 +67,27 @@ def has_luci_context_local_auth(): class Authenticator(object): - """Object that knows how to refresh access tokens when needed. + """Object that knows how to refresh access tokens or id tokens when needed. Args: - scopes: space separated oauth scopes. Defaults to OAUTH_SCOPE_EMAIL. + scopes: space separated oauth scopes. It's used to generate access tokens. + Defaults to OAUTH_SCOPE_EMAIL. + audience: An audience in ID tokens to claim which clients should accept it. """ - def __init__(self, scopes=OAUTH_SCOPE_EMAIL): + def __init__(self, scopes=OAUTH_SCOPE_EMAIL, audience=None): self._access_token = None self._scopes = scopes + self._id_token = None + self._audience = audience def has_cached_credentials(self): """Returns True if credentials can be obtained. - If returns False, get_access_token() later will probably ask for interactive - login by raising LoginRequiredError. + If returns False, get_access_token() or get_id_token() later will probably + ask for interactive login by raising LoginRequiredError. - If returns True, get_access_token() won't ask for interactive login. + If returns True, get_access_token() or get_id_token() won't ask for + interactive login. """ return bool(self._get_luci_auth_token()) @@ -105,7 +110,27 @@ class Authenticator(object): logging.error('Failed to create access token') raise LoginRequiredError(self._scopes) - def authorize(self, http): + def get_id_token(self): + """Returns id token, refreshing it if necessary. + + Returns: + A Token object. + + Raises: + LoginRequiredError if user interaction is required. + """ + if self._id_token and not self._id_token.needs_refresh(): + return self._id_token + + self._id_token = self._get_luci_auth_token(use_id_token=True) + if self._id_token and not self._id_token.needs_refresh(): + return self._id_token + + # Nope, still expired. Needs user interaction. + logging.error('Failed to create id token') + raise LoginRequiredError() + + def authorize(self, http, use_id_token=False): """Monkey patches authentication logic of httplib2.Http instance. The modified http.request method will add authentication headers to each @@ -128,8 +153,9 @@ class Authenticator(object): redirections=httplib2.DEFAULT_MAX_REDIRECTS, connection_type=None): headers = (headers or {}).copy() - headers['Authorization'] = 'Bearer %s' % self.get_access_token( - ).token + auth_token = self.get_access_token( + ) if not use_id_token else self.get_id_token() + headers['Authorization'] = 'Bearer %s' % auth_token.token return request_orig(uri, method, body, headers, redirections, connection_type) @@ -148,18 +174,21 @@ class Authenticator(object): subprocess2.check_call(['luci-auth', 'login', '-scopes', self._scopes]) return self._get_luci_auth_token() - def _get_luci_auth_token(self): + def _get_luci_auth_token(self, use_id_token=False): logging.debug('Running luci-auth token') + if use_id_token: + args = ['-use-id-token'] + ['-audience', self._audience + ] if self._audience else [] + else: + args = ['-scopes', self._scopes] try: - out, err = subprocess2.check_call_out([ - 'luci-auth', 'token', '-scopes', self._scopes, '-json-output', - '-' - ], + out, err = subprocess2.check_call_out(['luci-auth', 'token'] + + args + ['-json-output', '-'], stdout=subprocess2.PIPE, stderr=subprocess2.PIPE) logging.debug('luci-auth token stderr:\n%s', err) token_info = json.loads(out) - return AccessToken( + return Token( token_info['token'], datetime.datetime.utcfromtimestamp(token_info['expiry'])) except subprocess2.CalledProcessError as e: diff --git a/tests/auth_test.py b/tests/auth_test.py index 3886e8a4d6..ec80c32647 100755 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -52,8 +52,8 @@ class AuthenticatorTest(unittest.TestCase): def testGetAccessToken_CachedToken(self): authenticator = auth.Authenticator() - authenticator._access_token = auth.AccessToken('token', None) - self.assertEqual(auth.AccessToken('token', None), + authenticator._access_token = auth.Token('token', None) + self.assertEqual(auth.Token('token', None), authenticator.get_access_token()) subprocess2.check_call_out.assert_not_called() @@ -63,7 +63,7 @@ class AuthenticatorTest(unittest.TestCase): 'token': 'token', 'expiry': expiry }), '') - self.assertEqual(auth.AccessToken('token', VALID_EXPIRY), + self.assertEqual(auth.Token('token', VALID_EXPIRY), auth.Authenticator().get_access_token()) subprocess2.check_call_out.assert_called_with([ 'luci-auth', 'token', '-scopes', auth.OAUTH_SCOPE_EMAIL, @@ -78,7 +78,7 @@ class AuthenticatorTest(unittest.TestCase): 'token': 'token', 'expiry': expiry }), '') - self.assertEqual(auth.AccessToken('token', VALID_EXPIRY), + self.assertEqual(auth.Token('token', VALID_EXPIRY), auth.Authenticator('custom scopes').get_access_token()) subprocess2.check_call_out.assert_called_with([ 'luci-auth', 'token', '-scopes', 'custom scopes', '-json-output', @@ -87,41 +87,92 @@ class AuthenticatorTest(unittest.TestCase): stdout=subprocess2.PIPE, stderr=subprocess2.PIPE) - def testAuthorize(self): + def testAuthorize_AccessToken(self): http = mock.Mock() http_request = http.request http_request.__name__ = '__name__' authenticator = auth.Authenticator() - authenticator._access_token = auth.AccessToken('token', None) + authenticator._access_token = auth.Token('access_token', None) + authenticator._id_token = auth.Token('id_token', None) authorized = authenticator.authorize(http) authorized.request('https://example.com', method='POST', body='body', headers={'header': 'value'}) - http_request.assert_called_once_with('https://example.com', 'POST', - 'body', { - 'header': 'value', - 'Authorization': 'Bearer token' - }, mock.ANY, mock.ANY) + http_request.assert_called_once_with( + 'https://example.com', 'POST', 'body', { + 'header': 'value', + 'Authorization': 'Bearer access_token' + }, mock.ANY, mock.ANY) + + def testGetIdToken_NotLoggedIn(self): + subprocess2.check_call_out.side_effect = [ + subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', + 'stderr') + ] + self.assertRaises(auth.LoginRequiredError, + auth.Authenticator().get_id_token) + + def testGetIdToken_CachedToken(self): + authenticator = auth.Authenticator() + authenticator._id_token = auth.Token('token', None) + self.assertEqual(auth.Token('token', None), + authenticator.get_id_token()) + subprocess2.check_call_out.assert_not_called() + + def testGetIdToken_LoggedIn(self): + expiry = calendar.timegm(VALID_EXPIRY.timetuple()) + subprocess2.check_call_out.return_value = (json.dumps({ + 'token': 'token', + 'expiry': expiry + }), '') + self.assertEqual( + auth.Token('token', VALID_EXPIRY), + auth.Authenticator(audience='https://test.com').get_id_token()) + subprocess2.check_call_out.assert_called_with([ + 'luci-auth', 'token', '-use-id-token', '-audience', + 'https://test.com', '-json-output', '-' + ], + stdout=subprocess2.PIPE, + stderr=subprocess2.PIPE) + + def testAuthorize_IdToken(self): + http = mock.Mock() + http_request = http.request + http_request.__name__ = '__name__' + + authenticator = auth.Authenticator() + authenticator._access_token = auth.Token('access_token', None) + authenticator._id_token = auth.Token('id_token', None) + + authorized = authenticator.authorize(http, use_id_token=True) + authorized.request('https://example.com', + method='POST', + body='body', + headers={'header': 'value'}) + http_request.assert_called_once_with( + 'https://example.com', 'POST', 'body', { + 'header': 'value', + 'Authorization': 'Bearer id_token' + }, mock.ANY, mock.ANY) -class AccessTokenTest(unittest.TestCase): +class TokenTest(unittest.TestCase): def setUp(self): mock.patch('auth.datetime_now', return_value=NOW).start() self.addCleanup(mock.patch.stopall) def testNeedsRefresh_NoExpiry(self): - self.assertFalse(auth.AccessToken('token', None).needs_refresh()) + self.assertFalse(auth.Token('token', None).needs_refresh()) def testNeedsRefresh_Expired(self): expired = NOW + datetime.timedelta(seconds=30) - self.assertTrue(auth.AccessToken('token', expired).needs_refresh()) + self.assertTrue(auth.Token('token', expired).needs_refresh()) def testNeedsRefresh_Valid(self): - self.assertFalse( - auth.AccessToken('token', VALID_EXPIRY).needs_refresh()) + self.assertFalse(auth.Token('token', VALID_EXPIRY).needs_refresh()) class HasLuciContextLocalAuthTest(unittest.TestCase):