Reland "gerrit_util: Refactor ReadHttpResponse and add more tests."

This is a reland of 5bfa3ae88d

Replace cStringIO with StringIO and add tests.

Original change's description:
> gerrit_util: Refactor ReadHttpResponse and add more tests.
>
> Bug: 1016601
> Change-Id: Ie6afc5b1ea29888b0bf40bdb39b2b492d2d0494c
> Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/1880014
> Reviewed-by: Anthony Polito <apolito@google.com>
> Commit-Queue: Edward Lesmes <ehmaldonado@chromium.org>

Bug: 1016601
Change-Id: I0c83a83202169b6a1acc60bdf6f4cd00eac6e2a6
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/1884461
Reviewed-by: Anthony Polito <apolito@google.com>
Commit-Queue: Edward Lesmes <ehmaldonado@chromium.org>
changes/61/1884461/4
Edward Lemur 6 years ago committed by Commit Bot
parent 27eb01c355
commit 4ba192e7a9

@ -1 +1,23 @@
python_version: "3.8" python_version: "3.8"
# Used by:
# auth.py
# gerrit_util.py
# git_cl.py
# my_activity.py
# TODO(crbug.com/1002153): Add ninjalog_uploader.py
wheel: <
name: "infra/python/wheels/httplib2-py3"
version: "version:0.13.1"
>
# Used by:
# my_activity.py
wheel: <
name: "infra/python/wheels/python-dateutil-py2_py3"
version: "version:2.7.3"
>
wheel: <
name: "infra/python/wheels/six-py2_py3"
version: "version:1.10.0"
>

@ -9,11 +9,10 @@ https://gerrit-review.googlesource.com/Documentation/rest-api.html
""" """
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals
import base64 import base64
import contextlib import contextlib
import cookielib
import httplib # Still used for its constants.
import httplib2 import httplib2
import json import json
import logging import logging
@ -26,9 +25,6 @@ import stat
import sys import sys
import tempfile import tempfile
import time import time
import urllib
import urlparse
from cStringIO import StringIO
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
import auth import auth
@ -37,6 +33,16 @@ import metrics
import metrics_utils import metrics_utils
import subprocess2 import subprocess2
from third_party import six
from six.moves import urllib
if sys.version_info.major == 2:
import cookielib
from StringIO import StringIO
else:
import http.cookiejar as cookielib
from io import StringIO
LOGGER = logging.getLogger() LOGGER = logging.getLogger()
# With a starting sleep time of 1.5 seconds, 2^n exponential backoff, and seven # With a starting sleep time of 1.5 seconds, 2^n exponential backoff, and seven
# total tries, the sleep time between the first and last tries will be 94.5 sec. # total tries, the sleep time between the first and last tries will be 94.5 sec.
@ -48,16 +54,18 @@ TRY_LIMIT = 3
GERRIT_PROTOCOL = 'https' GERRIT_PROTOCOL = 'https'
def time_sleep(seconds):
  """Sleep for `seconds` seconds by delegating to time.sleep.

  Returns whatever time.sleep returns.
  """
  # Use this so that it can be mocked in tests without interfering with python
  # system machinery.
  return time.sleep(seconds)
class GerritError(Exception): class GerritError(Exception):
"""Exception class for errors communicating with the gerrit-on-borg service.""" """Exception class for errors communicating with the gerrit-on-borg service."""
def __init__(self, http_status, *args, **kwargs): def __init__(self, http_status, message, *args, **kwargs):
super(GerritError, self).__init__(*args, **kwargs) super(GerritError, self).__init__(*args, **kwargs)
self.http_status = http_status self.http_status = http_status
self.message = '(%d) %s' % (self.http_status, self.message) self.message = '(%d) %s' % (self.http_status, message)
class GerritAuthenticationError(GerritError):
"""Exception class for authentication errors during Gerrit communication."""
def _QueryString(params, first_param=None): def _QueryString(params, first_param=None):
@ -65,21 +73,11 @@ def _QueryString(params, first_param=None):
https://gerrit-review.googlesource.com/Documentation/rest-api-changes.html#list-changes https://gerrit-review.googlesource.com/Documentation/rest-api-changes.html#list-changes
""" """
q = [urllib.quote(first_param)] if first_param else [] q = [urllib.parse.quote(first_param)] if first_param else []
q.extend(['%s:%s' % (key, val) for key, val in params]) q.extend(['%s:%s' % (key, val) for key, val in params])
return '+'.join(q) return '+'.join(q)
def GetConnectionObject(protocol=None):
if protocol is None:
protocol = GERRIT_PROTOCOL
if protocol in ('http', 'https'):
return httplib2.Http()
else:
raise RuntimeError(
"Don't know how to work with protocol '%s'" % protocol)
class Authenticator(object): class Authenticator(object):
"""Base authenticator class for authenticator implementations to subclass.""" """Base authenticator class for authenticator implementations to subclass."""
@ -300,11 +298,13 @@ class GceAuthenticator(Authenticator):
def _get(url, **kwargs): def _get(url, **kwargs):
next_delay_sec = 1 next_delay_sec = 1
for i in xrange(TRY_LIMIT): for i in xrange(TRY_LIMIT):
p = urlparse.urlparse(url) p = urllib.parse.urlparse(url)
c = GetConnectionObject(protocol=p.scheme) if p.scheme not in ('http', 'https'):
resp, contents = c.request(url, 'GET', **kwargs) raise RuntimeError(
"Don't know how to work with protocol '%s'" % protocol)
resp, contents = httplib2.Http().request(url, 'GET', **kwargs)
LOGGER.debug('GET [%s] #%d/%d (%d)', url, i+1, TRY_LIMIT, resp.status) LOGGER.debug('GET [%s] #%d/%d (%d)', url, i+1, TRY_LIMIT, resp.status)
if resp.status < httplib.INTERNAL_SERVER_ERROR: if resp.status < 500:
return (resp, contents) return (resp, contents)
# Retry server error status codes. # Retry server error status codes.
@ -312,7 +312,7 @@ class GceAuthenticator(Authenticator):
if TRY_LIMIT - i > 1: if TRY_LIMIT - i > 1:
LOGGER.info('Will retry in %d seconds (%d more times)...', LOGGER.info('Will retry in %d seconds (%d more times)...',
next_delay_sec, TRY_LIMIT - i - 1) next_delay_sec, TRY_LIMIT - i - 1)
time.sleep(next_delay_sec) time_sleep(next_delay_sec)
next_delay_sec *= 2 next_delay_sec *= 2
@classmethod @classmethod
@ -323,7 +323,7 @@ class GceAuthenticator(Authenticator):
return cls._token_cache return cls._token_cache
resp, contents = cls._get(cls._ACQUIRE_URL, headers=cls._ACQUIRE_HEADERS) resp, contents = cls._get(cls._ACQUIRE_URL, headers=cls._ACQUIRE_HEADERS)
if resp.status != httplib.OK: if resp.status != 200:
return None return None
cls._token_cache = json.loads(contents) cls._token_cache = json.loads(contents)
cls._token_expiration = cls._token_cache['expires_in'] + time.time() cls._token_expiration = cls._token_cache['expires_in'] + time.time()
@ -370,7 +370,7 @@ def CreateHttpConn(host, path, reqtype='GET', headers=None, body=None):
url = '/a%s' % url url = '/a%s' % url
if body: if body:
body = json.JSONEncoder().encode(body) body = json.dumps(body, sort_keys=True)
headers.setdefault('Content-Type', 'application/json') headers.setdefault('Content-Type', 'application/json')
if LOGGER.isEnabledFor(logging.DEBUG): if LOGGER.isEnabledFor(logging.DEBUG):
LOGGER.debug('%s %s://%s%s' % (reqtype, GERRIT_PROTOCOL, host, url)) LOGGER.debug('%s %s://%s%s' % (reqtype, GERRIT_PROTOCOL, host, url))
@ -380,12 +380,12 @@ def CreateHttpConn(host, path, reqtype='GET', headers=None, body=None):
LOGGER.debug('%s: %s' % (key, val)) LOGGER.debug('%s: %s' % (key, val))
if body: if body:
LOGGER.debug(body) LOGGER.debug(body)
conn = GetConnectionObject() conn = httplib2.Http()
# HACK: httplib.Http has no such attribute; we store req_host here for later # HACK: httplib2.Http has no such attribute; we store req_host here for later
# use in ReadHttpResponse. # use in ReadHttpResponse.
conn.req_host = host conn.req_host = host
conn.req_params = { conn.req_params = {
'uri': urlparse.urljoin('%s://%s' % (GERRIT_PROTOCOL, host), url), 'uri': urllib.parse.urljoin('%s://%s' % (GERRIT_PROTOCOL, host), url),
'method': reqtype, 'method': reqtype,
'headers': headers, 'headers': headers,
'body': body, 'body': body,
@ -406,6 +406,7 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])):
for idx in range(TRY_LIMIT): for idx in range(TRY_LIMIT):
before_response = time.time() before_response = time.time()
response, contents = conn.request(**conn.req_params) response, contents = conn.request(**conn.req_params)
contents = contents.decode('utf-8', 'replace')
response_time = time.time() - before_response response_time = time.time() - before_response
metrics.collector.add_repeated( metrics.collector.add_repeated(
@ -414,21 +415,12 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])):
conn.req_params['uri'], conn.req_params['method'], response.status, conn.req_params['uri'], conn.req_params['method'], response.status,
response_time)) response_time))
# Check if this is an authentication issue. # If response.status is an accepted status,
www_authenticate = response.get('www-authenticate') # or response.status < 500 then the result is final; break retry loop.
if (response.status in (httplib.UNAUTHORIZED, httplib.FOUND) and # If the response is 404/409 it might be because of replication lag,
www_authenticate): # so keep trying anyway.
auth_match = re.search('realm="([^"]+)"', www_authenticate, re.I) if (response.status in accept_statuses
host = auth_match.group(1) if auth_match else conn.req_host or response.status < 500 and response.status not in [404, 409]):
reason = ('Authentication failed. Please make sure your .gitcookies file '
'has credentials for %s' % host)
raise GerritAuthenticationError(response.status, reason)
# If response.status < 500 then the result is final; break retry loop.
# If the response is 404/409, it might be because of replication lag, so
# keep trying anyway.
if ((response.status < 500 and response.status not in [404, 409])
or response.status in accept_statuses):
LOGGER.debug('got response %d for %s %s', response.status, LOGGER.debug('got response %d for %s %s', response.status,
conn.req_params['method'], conn.req_params['uri']) conn.req_params['method'], conn.req_params['uri'])
# If 404 was in accept_statuses, then it's expected that the file might # If 404 was in accept_statuses, then it's expected that the file might
@ -437,6 +429,7 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])):
if response.status == 404: if response.status == 404:
contents = '' contents = ''
break break
# A status >=500 is assumed to be a possible transient error; retry. # A status >=500 is assumed to be a possible transient error; retry.
http_version = 'HTTP/%s' % ('1.1' if response.version == 11 else '1.0') http_version = 'HTTP/%s' % ('1.1' if response.version == 11 else '1.0')
LOGGER.warn('A transient error occurred while querying %s:\n' LOGGER.warn('A transient error occurred while querying %s:\n'
@ -446,19 +439,29 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])):
conn.req_params['uri'], conn.req_params['uri'],
http_version, http_version, response.status, response.reason) http_version, http_version, response.status, response.reason)
if TRY_LIMIT - idx > 1: if idx < TRY_LIMIT - 1:
LOGGER.info('Will retry in %d seconds (%d more times)...', LOGGER.info('Will retry in %d seconds (%d more times)...',
sleep_time, TRY_LIMIT - idx - 1) sleep_time, TRY_LIMIT - idx - 1)
time.sleep(sleep_time) time_sleep(sleep_time)
sleep_time = sleep_time * 2 sleep_time = sleep_time * 2
# end of retries loop # end of retries loop
if response.status not in accept_statuses:
if response.status in (401, 403): if response.status in accept_statuses:
print('Your Gerrit credentials might be misconfigured. Try: \n' return StringIO(contents)
' git cl creds-check')
if response.status in (302, 401, 403):
www_authenticate = response.get('www-authenticate')
if not www_authenticate:
print('Your Gerrit credentials might be misconfigured.')
else:
auth_match = re.search('realm="([^"]+)"', www_authenticate, re.I)
host = auth_match.group(1) if auth_match else conn.req_host
print('Authentication failed. Please make sure your .gitcookies '
'file has credentials for %s.' % host)
print('Try:\n git cl creds-check')
reason = '%s: %s' % (response.reason, contents) reason = '%s: %s' % (response.reason, contents)
raise GerritError(response.status, reason) raise GerritError(response.status, reason)
return StringIO(contents)
def ReadHttpJsonResponse(conn, accept_statuses=frozenset([200])): def ReadHttpJsonResponse(conn, accept_statuses=frozenset([200])):
@ -575,7 +578,7 @@ def MultiQueryChanges(host, params, change_list, limit=None, o_params=None,
if not change_list: if not change_list:
raise RuntimeError( raise RuntimeError(
"MultiQueryChanges requires a list of change numbers/id's") "MultiQueryChanges requires a list of change numbers/id's")
q = ['q=%s' % '+OR+'.join([urllib.quote(str(x)) for x in change_list])] q = ['q=%s' % '+OR+'.join([urllib.parse.quote(str(x)) for x in change_list])]
if params: if params:
q.append(_QueryString(params)) q.append(_QueryString(params))
if limit: if limit:
@ -601,7 +604,8 @@ def GetGerritFetchUrl(host):
def GetCodeReviewTbrScore(host, project): def GetCodeReviewTbrScore(host, project):
"""Given a Gerrit host name and project, return the Code-Review score for TBR. """Given a Gerrit host name and project, return the Code-Review score for TBR.
""" """
conn = CreateHttpConn(host, '/projects/%s' % urllib.quote(project, safe='')) conn = CreateHttpConn(
host, '/projects/%s' % urllib.parse.quote(project, ''))
project = ReadHttpJsonResponse(conn) project = ReadHttpJsonResponse(conn)
if ('labels' not in project if ('labels' not in project
or 'Code-Review' not in project['labels'] or 'Code-Review' not in project['labels']
@ -988,4 +992,4 @@ def ChangeIdentifier(project, change_number):
comparing to specifying just change_number. comparing to specifying just change_number.
""" """
assert int(change_number) assert int(change_number)
return '%s~%s' % (urllib.quote(project, safe=''), change_number) return '%s~%s' % (urllib.parse.quote(project, ''), change_number)

@ -1,4 +1,4 @@
#!/usr/bin/env vpython #!/usr/bin/env vpython3
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019 The Chromium Authors. All rights reserved. # Copyright (c) 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be # Use of this source code is governed by a BSD-style license that can be
@ -9,6 +9,7 @@ from __future__ import unicode_literals
import base64 import base64
import json
import os import os
import sys import sys
import unittest import unittest
@ -19,8 +20,15 @@ from third_party import mock
import gerrit_util import gerrit_util
import gclient_utils import gclient_utils
import metrics
import metrics_utils
import subprocess2 import subprocess2
if sys.version_info.major == 2:
from cStringIO import StringIO
else:
from io import StringIO
class CookiesAuthenticatorTest(unittest.TestCase): class CookiesAuthenticatorTest(unittest.TestCase):
_GITCOOKIES = '\n'.join([ _GITCOOKIES = '\n'.join([
@ -167,5 +175,214 @@ class CookiesAuthenticatorTest(unittest.TestCase):
self.assertIsNone(auth.get_auth_email('some-review.example.com')) self.assertIsNone(auth.get_auth_email('some-review.example.com'))
class GerritUtilTest(unittest.TestCase):
  """Tests for the query, connection and response helpers in gerrit_util."""

  def setUp(self):
    """Patch out logging, sleeping and metrics collection for every test."""
    super(GerritUtilTest, self).setUp()
    # Register the cleanup before starting any patch: cleanups still run when
    # setUp itself fails, so a failing start() below cannot leak the patches
    # that were already started into other tests.
    self.addCleanup(mock.patch.stopall)
    mock.patch('gerrit_util.LOGGER').start()
    mock.patch('gerrit_util.time_sleep').start()
    mock.patch('metrics.collector').start()
    mock.patch(
        'metrics_utils.extract_http_metrics',
        return_value='http_metrics').start()

  def testQueryString(self):
    """_QueryString quotes first_param and joins key:val pairs with '+'."""
    self.assertEqual('', gerrit_util._QueryString([]))
    self.assertEqual(
        'first%20param%2B', gerrit_util._QueryString([], 'first param+'))
    self.assertEqual(
        'key:val+foo:bar',
        gerrit_util._QueryString([('key', 'val'), ('foo', 'bar')]))
    self.assertEqual(
        'first%20param%2B+key:val+foo:bar',
        gerrit_util._QueryString(
            [('key', 'val'), ('foo', 'bar')], 'first param+'))

  @mock.patch('gerrit_util.Authenticator')
  def testCreateHttpConn_Basic(self, mockAuth):
    """Without an auth header, the URI is unprefixed and headers are empty."""
    mockAuth.get().get_auth_header.return_value = None
    conn = gerrit_util.CreateHttpConn('host.example.com', 'foo/bar')
    self.assertEqual('host.example.com', conn.req_host)
    self.assertEqual({
        'uri': 'https://host.example.com/foo/bar',
        'method': 'GET',
        'headers': {},
        'body': None,
    }, conn.req_params)

  @mock.patch('gerrit_util.Authenticator')
  def testCreateHttpConn_Authenticated(self, mockAuth):
    """With an auth header, the path gains '/a' and Authorization is set."""
    mockAuth.get().get_auth_header.return_value = 'Bearer token'
    conn = gerrit_util.CreateHttpConn(
        'host.example.com', 'foo/bar', headers={'header': 'value'})
    self.assertEqual('host.example.com', conn.req_host)
    self.assertEqual({
        'uri': 'https://host.example.com/a/foo/bar',
        'method': 'GET',
        'headers': {'Authorization': 'Bearer token', 'header': 'value'},
        'body': None,
    }, conn.req_params)

  @mock.patch('gerrit_util.Authenticator')
  def testCreateHttpConn_Body(self, mockAuth):
    """A body is serialized to key-sorted JSON and Content-Type is added."""
    mockAuth.get().get_auth_header.return_value = None
    conn = gerrit_util.CreateHttpConn(
        'host.example.com', 'foo/bar', body={'l': [1, 2, 3], 'd': {'k': 'v'}})
    self.assertEqual('host.example.com', conn.req_host)
    self.assertEqual({
        'uri': 'https://host.example.com/foo/bar',
        'method': 'GET',
        'headers': {'Content-Type': 'application/json'},
        'body': '{"d": {"k": "v"}, "l": [1, 2, 3]}',
    }, conn.req_params)

  def testReadHttpResponse_200(self):
    """A 200 response is decoded as UTF-8 and one metrics entry is recorded."""
    conn = mock.Mock()
    conn.req_params = {'uri': 'uri', 'method': 'method'}
    conn.request.return_value = (mock.Mock(status=200), b'content\xe2\x9c\x94')

    content = gerrit_util.ReadHttpResponse(conn)
    self.assertEqual('content✔', content.getvalue())
    metrics.collector.add_repeated.assert_called_once_with(
        'http_requests', 'http_metrics')

  def testReadHttpResponse_AuthenticationIssue(self):
    """302/401/403 raise GerritError and print a credentials hint."""
    for status in (302, 401, 403):
      response = mock.Mock(status=status)
      # No 'www-authenticate' header is reported for this response.
      response.get.return_value = None
      conn = mock.Mock(req_params={'uri': 'uri', 'method': 'method'})
      conn.request.return_value = (response, b'')

      # stdout is replaced so the printed hint can be inspected; the
      # assertions must stay inside this block while the patch is active.
      with mock.patch('sys.stdout', StringIO()):
        with self.assertRaises(gerrit_util.GerritError) as cm:
          gerrit_util.ReadHttpResponse(conn)

        self.assertEqual(status, cm.exception.http_status)
        self.assertIn(
            'Your Gerrit credentials might be misconfigured',
            sys.stdout.getvalue())

  def testReadHttpResponse_ClientError(self):
    """A 404 that is not in accept_statuses raises GerritError(404)."""
    conn = mock.Mock(req_params={'uri': 'uri', 'method': 'method'})
    conn.request.return_value = (mock.Mock(status=404), b'')

    with self.assertRaises(gerrit_util.GerritError) as cm:
      gerrit_util.ReadHttpResponse(conn)

    self.assertEqual(404, cm.exception.http_status)

  def testReadHttpResponse_ServerError(self):
    """A persistent 500 is retried TRY_LIMIT times, sleeping 1.5s then 3s."""
    conn = mock.Mock(req_params={'uri': 'uri', 'method': 'method'})
    conn.request.return_value = (mock.Mock(status=500), b'')

    with self.assertRaises(gerrit_util.GerritError) as cm:
      gerrit_util.ReadHttpResponse(conn)

    self.assertEqual(500, cm.exception.http_status)
    self.assertEqual(gerrit_util.TRY_LIMIT, len(conn.request.mock_calls))
    self.assertEqual(
        [mock.call(1.5), mock.call(3)], gerrit_util.time_sleep.mock_calls)

  def testReadHttpResponse_ServerErrorAndSuccess(self):
    """A 500 followed by a 200 succeeds after a single 1.5s sleep."""
    conn = mock.Mock(req_params={'uri': 'uri', 'method': 'method'})
    conn.request.side_effect = [
        (mock.Mock(status=500), b''),
        (mock.Mock(status=200), b'content\xe2\x9c\x94'),
    ]

    self.assertEqual('content✔', gerrit_util.ReadHttpResponse(conn).getvalue())
    self.assertEqual(2, len(conn.request.mock_calls))
    gerrit_util.time_sleep.assert_called_once_with(1.5)

  def testReadHttpResponse_Expected404(self):
    """When 404 is an accepted status, the returned contents are empty."""
    conn = mock.Mock()
    conn.req_params = {'uri': 'uri', 'method': 'method'}
    conn.request.return_value = (mock.Mock(status=404), b'content\xe2\x9c\x94')

    content = gerrit_util.ReadHttpResponse(conn, (404,))
    self.assertEqual('', content.getvalue())

  @mock.patch('gerrit_util.ReadHttpResponse')
  def testReadHttpJsonResponse_NotJSON(self, mockReadHttpResponse):
    """A body that is not JSON raises GerritError with status 200."""
    mockReadHttpResponse.return_value = StringIO('not json')
    with self.assertRaises(gerrit_util.GerritError) as cm:
      gerrit_util.ReadHttpJsonResponse(None)
    self.assertEqual(cm.exception.http_status, 200)
    self.assertEqual(
        cm.exception.message, '(200) Unexpected json output: not json')

  @mock.patch('gerrit_util.ReadHttpResponse')
  def testReadHttpJsonResponse_EmptyValue(self, mockReadHttpResponse):
    """A body holding only the )]}' prefix yields None."""
    mockReadHttpResponse.return_value = StringIO(')]}\'')
    self.assertIsNone(gerrit_util.ReadHttpJsonResponse(None))

  @mock.patch('gerrit_util.ReadHttpResponse')
  def testReadHttpJsonResponse_JSON(self, mockReadHttpResponse):
    """The JSON document following the )]}' prefix line is parsed."""
    expected_value = {'foo': 'bar', 'baz': [1, '2', 3]}
    mockReadHttpResponse.return_value = StringIO(
        ')]}\'\n' + json.dumps(expected_value))
    self.assertEqual(expected_value, gerrit_util.ReadHttpJsonResponse(None))

  # Decorators apply bottom-up, so the ReadHttpJsonResponse mock is the first
  # injected argument and the CreateHttpConn mock is the second.
  @mock.patch('gerrit_util.CreateHttpConn')
  @mock.patch('gerrit_util.ReadHttpJsonResponse')
  def testQueryChanges(self, mockJsonResponse, mockCreateHttpConn):
    """QueryChanges builds the changes/ URL from params, start, n and o."""
    gerrit_util.QueryChanges(
        'host', [('key', 'val'), ('foo', 'bar')], 'first param', limit=500,
        o_params=['PARAM_A', 'PARAM_B'], start='start')
    mockCreateHttpConn.assert_called_once_with(
        'host',
        ('changes/?q=first%20param+key:val+foo:bar'
         '&start=start'
         '&n=500'
         '&o=PARAM_A'
         '&o=PARAM_B'))

  def testQueryChanges_NoParams(self):
    """QueryChanges raises RuntimeError when given no params."""
    self.assertRaises(RuntimeError, gerrit_util.QueryChanges, 'host', [])

  @mock.patch('gerrit_util.QueryChanges')
  def testGenerateAllChanges(self, mockQueryChanges):
    """Each change is yielded once across overlapping pages and a re-query."""
    mockQueryChanges.side_effect = [
        # First results page.
        [
            {'_number': '4'},
            {'_number': '3'},
            {'_number': '2', '_more_changes': True},
        ],
        # Second results page. New changes have appeared in the meantime, so
        # this page repeats a result ('2') already seen on the first page.
        [
            {'_number': '2'},
            {'_number': '1'},
        ],
        # GenerateAllChanges queries again from the start to get any new
        # changes (5 in this case).
        [
            {'_number': '5'},
            {'_number': '4'},
            {'_number': '3', '_more_changes': True},
        ],
    ]

    changes = list(gerrit_util.GenerateAllChanges('host', 'params'))

    self.assertEqual(
        [
            {'_number': '4'},
            {'_number': '3'},
            {'_number': '2', '_more_changes': True},
            {'_number': '1'},
            {'_number': '5'},
        ],
        changes)
    self.assertEqual(
        [
            mock.call('host', 'params', None, 500, None, 0),
            mock.call('host', 'params', None, 500, None, 3),
            mock.call('host', 'params', None, 500, None, 0),
        ],
        mockQueryChanges.mock_calls)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save