[gerrit_util] Change Authenticator API to return proxy info.

This will be used with an upcoming SSOAuthenticator implementation
which will need to proxy all http requests for Googlers.

R=ayatane, gavinmak@google.com

Bug: 336351842
Change-Id: If8cbb8db51fce198e704f109232868421130b40c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/5582100
Commit-Queue: Gavin Mak <gavinmak@google.com>
Auto-Submit: Robbie Iannucci <iannucci@chromium.org>
Reviewed-by: Gavin Mak <gavinmak@google.com>
changes/00/5582100/2
Robert Iannucci 1 year ago committed by LUCI CQ
parent 11ed5e0222
commit c57b7ed364

@ -9,7 +9,7 @@ https://gerrit-review.googlesource.com/Documentation/rest-api.html
import base64 import base64
import contextlib import contextlib
from typing import List, Type from typing import List, Optional, Tuple, Type
import httplib2 import httplib2
import json import json
import logging import logging
@ -96,7 +96,17 @@ def _QueryString(params, first_param=None):
class Authenticator(object): class Authenticator(object):
"""Base authenticator class for authenticator implementations to subclass.""" """Base authenticator class for authenticator implementations to subclass."""
def get_auth_header(self, host): def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
"""Returns the Authorization header value, plus an optional ProxyInfo.
TODO: Remove `host`. This is only needed for the deprecated
CookiesAuthenticator. If distinguishing between hosts is still needed
later, I would propose moving this parameter to
Authenticator.get/Authenticator.is_applicable/Authenticator.__init__
instead.
TODO: Make auth header non-optional.
"""
raise NotImplementedError() raise NotImplementedError()
def debug_summary_state(self) -> str: def debug_summary_state(self) -> str:
@ -231,16 +241,16 @@ class CookiesAuthenticator(Authenticator):
return (creds[0], None, creds[1]) return (creds[0], None, creds[1])
return None return None
def get_auth_header(self, host): def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
a = self._get_auth_for_host(host) a = self._get_auth_for_host(host)
if a: if a:
if a[0]: if a[0]:
secret = base64.b64encode( secret = base64.b64encode(
('%s:%s' % (a[0], a[2])).encode('utf-8')) ('%s:%s' % (a[0], a[2])).encode('utf-8'))
return 'Basic %s' % secret.decode('utf-8') return 'Basic %s' % secret.decode('utf-8'), None
return 'Bearer %s' % a[2] return 'Bearer %s' % a[2], None
return None return None, None
# Used to redact the cookies from the gitcookies file. # Used to redact the cookies from the gitcookies file.
GITCOOKIES_REDACT_RE = re.compile(r'1/.*') GITCOOKIES_REDACT_RE = re.compile(r'1/.*')
@ -333,11 +343,11 @@ class GceAuthenticator(Authenticator):
cls._token_expiration = cls._token_cache['expires_in'] + time_time() cls._token_expiration = cls._token_cache['expires_in'] + time_time()
return cls._token_cache return cls._token_cache
def get_auth_header(self, _host): def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
token_dict = self._get_token_dict() token_dict = self._get_token_dict()
if not token_dict: if not token_dict:
return None return None, None
return '%(token_type)s %(access_token)s' % token_dict return '%(token_type)s %(access_token)s' % token_dict, None
def debug_summary_state(self) -> str: def debug_summary_state(self) -> str:
# TODO(b/343230702) - report ambient account name. # TODO(b/343230702) - report ambient account name.
@ -355,8 +365,8 @@ class LuciContextAuthenticator(Authenticator):
self._authenticator = auth.Authenticator(' '.join( self._authenticator = auth.Authenticator(' '.join(
[auth.OAUTH_SCOPE_EMAIL, auth.OAUTH_SCOPE_GERRIT])) [auth.OAUTH_SCOPE_EMAIL, auth.OAUTH_SCOPE_GERRIT]))
def get_auth_header(self, _host): def get_auth_info(self, host: str) -> Tuple[Optional[str], Optional[httplib2.ProxyInfo]]:
return 'Bearer %s' % self._authenticator.get_access_token().token return 'Bearer %s' % self._authenticator.get_access_token().token, None
def debug_summary_state(self) -> str: def debug_summary_state(self) -> str:
# TODO(b/343230702) - report ambient account name. # TODO(b/343230702) - report ambient account name.
@ -373,22 +383,22 @@ def CreateHttpConn(host,
headers = headers or {} headers = headers or {}
bare_host = host.partition(':')[0] bare_host = host.partition(':')[0]
a = Authenticator.get() authenticator = Authenticator.get()
# TODO(crbug.com/1059384): Automatically detect when running on cloudtop. # TODO(crbug.com/1059384): Automatically detect when running on cloudtop.
if isinstance(a, GceAuthenticator): if isinstance(authenticator, GceAuthenticator):
print('If you\'re on a cloudtop instance, export ' print('If you\'re on a cloudtop instance, export '
'SKIP_GCE_AUTH_FOR_GIT=1 in your env.') 'SKIP_GCE_AUTH_FOR_GIT=1 in your env.')
a = a.get_auth_header(bare_host) auth_header, proxy = authenticator.get_auth_info(bare_host)
if a: if auth_header:
headers.setdefault('Authorization', a) headers.setdefault('Authorization', auth_header)
else: else:
LOGGER.debug('No authorization found for %s.' % bare_host) LOGGER.debug('No authorization found for %s.' % bare_host)
url = path url = path
if not url.startswith('/'): if not url.startswith('/'):
url = '/' + url url = '/' + url
if 'Authorization' in headers and not url.startswith('/a/'): if auth_header and not url.startswith('/a/'):
url = '/a%s' % url url = '/a%s' % url
if body: if body:
@ -402,7 +412,7 @@ def CreateHttpConn(host,
LOGGER.debug('%s: %s' % (key, val)) LOGGER.debug('%s: %s' % (key, val))
if body: if body:
LOGGER.debug(body) LOGGER.debug(body)
conn = httplib2.Http(timeout=timeout) conn = httplib2.Http(timeout=timeout, proxy_info=proxy)
# HACK: httplib2.Http has no such attribute; we store req_host here for # HACK: httplib2.Http has no such attribute; we store req_host here for
# later use in ReadHttpResponse. # later use in ReadHttpResponse.
conn.req_host = host conn.req_host = host

@ -2312,12 +2312,12 @@ class Changelist(object):
git_host = self._GetGitHost() git_host = self._GetGitHost()
assert self._gerrit_server and self._gerrit_host and git_host assert self._gerrit_server and self._gerrit_host and git_host
gerrit_auth = cookie_auth.get_auth_header(self._gerrit_host) gerrit_auth, _ = cookie_auth.get_auth_info(self._gerrit_host)
git_auth = cookie_auth.get_auth_header(git_host) git_auth, _ = cookie_auth.get_auth_info(git_host)
if gerrit_auth and git_auth: if gerrit_auth and git_auth:
if gerrit_auth == git_auth: if gerrit_auth == git_auth:
return return
all_gsrc = cookie_auth.get_auth_header( all_gsrc, _ = cookie_auth.get_auth_info(
'd0esN0tEx1st.googlesource.com') 'd0esN0tEx1st.googlesource.com')
print( print(
'WARNING: You have different credentials for Gerrit and git hosts:\n' 'WARNING: You have different credentials for Gerrit and git hosts:\n'

@ -151,13 +151,13 @@ class CookiesAuthenticatorTest(unittest.TestCase):
'Basic Z2l0LXVzZXIuY2hyb21pdW0ub3JnOjEvY2hyb21pdW0tc2VjcmV0') 'Basic Z2l0LXVzZXIuY2hyb21pdW0ub3JnOjEvY2hyb21pdW0tc2VjcmV0')
auth = gerrit_util.CookiesAuthenticator() auth = gerrit_util.CookiesAuthenticator()
self.assertEqual(expected_chromium_header, self.assertEqual((expected_chromium_header, None),
auth.get_auth_header('chromium.googlesource.com')) auth.get_auth_info('chromium.googlesource.com'))
self.assertEqual( self.assertEqual(
expected_chromium_header, (expected_chromium_header, None),
auth.get_auth_header('chromium-review.googlesource.com')) auth.get_auth_info('chromium-review.googlesource.com'))
self.assertEqual('Bearer example-bearer-token', self.assertEqual(('Bearer example-bearer-token', None),
auth.get_auth_header('some-review.example.com')) auth.get_auth_info('some-review.example.com'))
def testGetAuthEmail(self): def testGetAuthEmail(self):
auth = gerrit_util.CookiesAuthenticator() auth = gerrit_util.CookiesAuthenticator()
@ -226,15 +226,21 @@ class GceAuthenticatorTest(unittest.TestCase):
def testGetAuthHeader_Error(self): def testGetAuthHeader_Error(self):
httplib2.Http().request.side_effect = httplib2.HttpLib2Error httplib2.Http().request.side_effect = httplib2.HttpLib2Error
self.assertIsNone(self.GceAuthenticator().get_auth_header('')) self.assertEqual(
(None, None),
self.GceAuthenticator().get_auth_info(''))
def testGetAuthHeader_500(self): def testGetAuthHeader_500(self):
httplib2.Http().request.return_value = (mock.Mock(status=500), None) httplib2.Http().request.return_value = (mock.Mock(status=500), None)
self.assertIsNone(self.GceAuthenticator().get_auth_header('')) self.assertEqual(
(None, None),
self.GceAuthenticator().get_auth_info(''))
def testGetAuthHeader_Non200(self): def testGetAuthHeader_Non200(self):
httplib2.Http().request.return_value = (mock.Mock(status=403), None) httplib2.Http().request.return_value = (mock.Mock(status=403), None)
self.assertIsNone(self.GceAuthenticator().get_auth_header('')) self.assertEqual(
(None, None),
self.GceAuthenticator().get_auth_info(''))
def testGetAuthHeader_OK(self): def testGetAuthHeader_OK(self):
httplib2.Http().request.return_value = ( httplib2.Http().request.return_value = (
@ -242,8 +248,8 @@ class GceAuthenticatorTest(unittest.TestCase):
'{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}' '{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}'
) )
gerrit_util.time_time.return_value = 0 gerrit_util.time_time.return_value = 0
self.assertEqual('TYPE TOKEN', self.assertEqual(('TYPE TOKEN', None),
self.GceAuthenticator().get_auth_header('')) self.GceAuthenticator().get_auth_info(''))
def testGetAuthHeader_Cache(self): def testGetAuthHeader_Cache(self):
httplib2.Http().request.return_value = ( httplib2.Http().request.return_value = (
@ -251,10 +257,10 @@ class GceAuthenticatorTest(unittest.TestCase):
'{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}' '{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}'
) )
gerrit_util.time_time.return_value = 0 gerrit_util.time_time.return_value = 0
self.assertEqual('TYPE TOKEN', self.assertEqual(('TYPE TOKEN', None),
self.GceAuthenticator().get_auth_header('')) self.GceAuthenticator().get_auth_info(''))
self.assertEqual('TYPE TOKEN', self.assertEqual(('TYPE TOKEN', None),
self.GceAuthenticator().get_auth_header('')) self.GceAuthenticator().get_auth_info(''))
httplib2.Http().request.assert_called_once() httplib2.Http().request.assert_called_once()
def testGetAuthHeader_CacheOld(self): def testGetAuthHeader_CacheOld(self):
@ -263,10 +269,10 @@ class GceAuthenticatorTest(unittest.TestCase):
'{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}' '{"expires_in": 125, "token_type": "TYPE", "access_token": "TOKEN"}'
) )
gerrit_util.time_time.side_effect = [0, 100, 200] gerrit_util.time_time.side_effect = [0, 100, 200]
self.assertEqual('TYPE TOKEN', self.assertEqual(('TYPE TOKEN', None),
self.GceAuthenticator().get_auth_header('')) self.GceAuthenticator().get_auth_info(''))
self.assertEqual('TYPE TOKEN', self.assertEqual(('TYPE TOKEN', None),
self.GceAuthenticator().get_auth_header('')) self.GceAuthenticator().get_auth_info(''))
self.assertEqual(2, len(httplib2.Http().request.mock_calls)) self.assertEqual(2, len(httplib2.Http().request.mock_calls))
@ -294,7 +300,7 @@ class GerritUtilTest(unittest.TestCase):
@mock.patch('gerrit_util.Authenticator') @mock.patch('gerrit_util.Authenticator')
def testCreateHttpConn_Basic(self, mockAuth): def testCreateHttpConn_Basic(self, mockAuth):
mockAuth.get().get_auth_header.return_value = None mockAuth.get().get_auth_info.return_value = None, None
conn = gerrit_util.CreateHttpConn('host.example.com', 'foo/bar') conn = gerrit_util.CreateHttpConn('host.example.com', 'foo/bar')
self.assertEqual('host.example.com', conn.req_host) self.assertEqual('host.example.com', conn.req_host)
self.assertEqual( self.assertEqual(
@ -307,7 +313,7 @@ class GerritUtilTest(unittest.TestCase):
@mock.patch('gerrit_util.Authenticator') @mock.patch('gerrit_util.Authenticator')
def testCreateHttpConn_Authenticated(self, mockAuth): def testCreateHttpConn_Authenticated(self, mockAuth):
mockAuth.get().get_auth_header.return_value = 'Bearer token' mockAuth.get().get_auth_info.return_value = 'Bearer token', None
conn = gerrit_util.CreateHttpConn('host.example.com', conn = gerrit_util.CreateHttpConn('host.example.com',
'foo/bar', 'foo/bar',
headers={'header': 'value'}) headers={'header': 'value'})
@ -325,7 +331,7 @@ class GerritUtilTest(unittest.TestCase):
@mock.patch('gerrit_util.Authenticator') @mock.patch('gerrit_util.Authenticator')
def testCreateHttpConn_Body(self, mockAuth): def testCreateHttpConn_Body(self, mockAuth):
mockAuth.get().get_auth_header.return_value = None mockAuth.get().get_auth_info.return_value = None, None
conn = gerrit_util.CreateHttpConn('host.example.com', conn = gerrit_util.CreateHttpConn('host.example.com',
'foo/bar', 'foo/bar',
body={ body={

@ -2598,7 +2598,7 @@ class TestGitCl(unittest.TestCase):
'chromium-review.googlesource.com': ('', None, 'secret'), 'chromium-review.googlesource.com': ('', None, 'secret'),
}) })
self.assertIsNone(cl.EnsureAuthenticated(force=False)) self.assertIsNone(cl.EnsureAuthenticated(force=False))
header = gerrit_util.CookiesAuthenticator().get_auth_header( header, _ = gerrit_util.CookiesAuthenticator().get_auth_info(
'chromium.googlesource.com') 'chromium.googlesource.com')
self.assertTrue('Bearer' in header) self.assertTrue('Bearer' in header)

Loading…
Cancel
Save