[gerrit_util] Factor out SSOHelper

Add a layer of abstraction/isolation for general organization.

Also, this logic needs to be used in Git setup too, not just Gerrit
authentication.

Bug: b/348024314
Change-Id: Ie1310a9b8e71c05c72a4b987dcbff76b70c67945
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/5645906
Commit-Queue: Allen Li <ayatane@chromium.org>
Reviewed-by: Yiwei Zhang <yiwzhang@google.com>
changes/06/5645906/5
Allen Li 1 year ago committed by LUCI CQ
parent c4c3d5326e
commit bdf64705c3

@ -11,7 +11,6 @@ from __future__ import annotations
import base64 import base64
import contextlib import contextlib
import functools
import http.cookiejar import http.cookiejar
import json import json
import logging import logging
@ -149,6 +148,29 @@ def _QueryString(params, first_param=None):
return '+'.join(q) return '+'.join(q)
class SSOHelper(object):
"""SSOHelper finds a Google-internal SSO helper."""
_sso_cmd: Optional[str] = None
def find_cmd(self) -> str:
"""Returns the cached command-line to invoke git-remote-sso.
If git-remote-sso is not in $PATH, returns None.
"""
if self._sso_cmd is not None:
return self._sso_cmd
cmd = shutil.which('git-remote-sso')
if cmd is None:
cmd = ''
self._sso_cmd = cmd
return cmd
# Global instance
ssoHelper = SSOHelper()
class Authenticator(object): class Authenticator(object):
"""Base authenticator class for authenticator implementations to subclass.""" """Base authenticator class for authenticator implementations to subclass."""
@ -234,16 +256,6 @@ class SSOAuthenticator(Authenticator):
# Overridden in tests. # Overridden in tests.
_timeout_secs = 5 _timeout_secs = 5
# Tri-state cache for sso helper command:
# * None - no lookup yet
# * () - lookup was performed, but no binary was found.
# * non-empty tuple - lookup was performed, and this is the command to
# run.
#
# NOTE: Tests directly assign to this to substitute a helper process that
# can exercise other aspects of SSOAuthenticator.
_sso_cmd: Optional[Tuple[str, ...]] = None
@dataclass @dataclass
class SSOInfo: class SSOInfo:
proxy: httplib2.ProxyInfo proxy: httplib2.ProxyInfo
@ -260,19 +272,14 @@ class SSOAuthenticator(Authenticator):
If git-remote-sso is not in $PATH, returns (). If git-remote-sso is not in $PATH, returns ().
""" """
cmd = cls._sso_cmd cmd = ssoHelper.find_cmd()
if cmd is None: if not cmd:
pth = shutil.which('git-remote-sso') return ()
if pth is None: return (
cmd = () cmd,
else: '-print_config',
cmd = ( 'sso://*.git.corp.google.com',
pth, )
'-print_config',
'sso://*.git.corp.google.com',
)
cls._sso_cmd = cmd
return cmd
@classmethod @classmethod
def is_applicable(cls) -> bool: def is_applicable(cls) -> bool:

@ -599,7 +599,6 @@ class SSOAuthenticatorTest(unittest.TestCase):
return super().setUpClass() return super().setUpClass()
def setUp(self) -> None: def setUp(self) -> None:
gerrit_util.SSOAuthenticator._sso_cmd = None
gerrit_util.SSOAuthenticator._sso_info = None gerrit_util.SSOAuthenticator._sso_info = None
gerrit_util.SSOAuthenticator._testing_load_expired_cookies = True gerrit_util.SSOAuthenticator._testing_load_expired_cookies = True
gerrit_util.SSOAuthenticator._timeout_secs = self._original_timeout_secs gerrit_util.SSOAuthenticator._timeout_secs = self._original_timeout_secs
@ -607,7 +606,6 @@ class SSOAuthenticatorTest(unittest.TestCase):
return super().setUp() return super().setUp()
def tearDown(self) -> None: def tearDown(self) -> None:
gerrit_util.SSOAuthenticator._sso_cmd = None
gerrit_util.SSOAuthenticator._sso_info = None gerrit_util.SSOAuthenticator._sso_info = None
gerrit_util.SSOAuthenticator._testing_load_expired_cookies = False gerrit_util.SSOAuthenticator._testing_load_expired_cookies = False
gerrit_util.SSOAuthenticator._timeout_secs = self._original_timeout_secs gerrit_util.SSOAuthenticator._timeout_secs = self._original_timeout_secs
@ -619,24 +617,19 @@ class SSOAuthenticatorTest(unittest.TestCase):
# Here _testMethodName would be a string like "testCmdAssemblyFound" # Here _testMethodName would be a string like "testCmdAssemblyFound"
return base / self._testMethodName return base / self._testMethodName
@mock.patch('shutil.which', return_value='/fake/git-remote-sso') @mock.patch('gerrit_util.ssoHelper.find_cmd',
return_value='/fake/git-remote-sso')
def testCmdAssemblyFound(self, _): def testCmdAssemblyFound(self, _):
self.assertEqual(self.sso._resolve_sso_cmd(), self.assertEqual(self.sso._resolve_sso_cmd(),
('/fake/git-remote-sso', '-print_config', ('/fake/git-remote-sso', '-print_config',
'sso://*.git.corp.google.com')) 'sso://*.git.corp.google.com'))
self.assertTrue(self.sso.is_applicable()) self.assertTrue(self.sso.is_applicable())
@mock.patch('shutil.which', return_value=None) @mock.patch('gerrit_util.ssoHelper.find_cmd', return_value=None)
def testCmdAssemblyNotFound(self, _): def testCmdAssemblyNotFound(self, _):
self.assertEqual(self.sso._resolve_sso_cmd(), ()) self.assertEqual(self.sso._resolve_sso_cmd(), ())
self.assertFalse(self.sso.is_applicable()) self.assertFalse(self.sso.is_applicable())
@mock.patch('shutil.which', return_value='/fake/git-remote-sso')
def testCmdAssemblyCached(self, which):
self.sso._resolve_sso_cmd()
self.sso._resolve_sso_cmd()
self.assertEqual(which.called, 1)
def testParseConfigOK(self): def testParseConfigOK(self):
parsed = self.sso._parse_config( parsed = self.sso._parse_config(
textwrap.dedent(f''' textwrap.dedent(f'''
@ -696,5 +689,26 @@ class SSOAuthenticatorTest(unittest.TestCase):
self.sso._get_sso_info() self.sso._get_sso_info()
class SSOHelperTest(unittest.TestCase):
def setUp(self) -> None:
self.sso = gerrit_util.SSOHelper()
return super().setUp()
@mock.patch('shutil.which', return_value='/fake/git-remote-sso')
def testFindCmd(self, _):
self.assertEqual(self.sso.find_cmd(), '/fake/git-remote-sso')
@mock.patch('shutil.which', return_value=None)
def testFindCmdMissing(self, _):
self.assertEqual(self.sso.find_cmd(), '')
@mock.patch('shutil.which', return_value='/fake/git-remote-sso')
def testFindCmdCached(self, which):
self.sso.find_cmd()
self.sso.find_cmd()
self.assertEqual(which.called, 1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save