diff --git a/rietveld.py b/rietveld.py index 9adc9c5ca..22168f572 100644 --- a/rietveld.py +++ b/rietveld.py @@ -21,6 +21,7 @@ import logging import re import socket import ssl +import StringIO import sys import time import urllib @@ -409,23 +410,20 @@ class Rietveld(object): if m: # Fake an HTTPError exception. Cheezy. :( raise urllib2.HTTPError( - request_path, int(m.group(1)), msg, None, None) + request_path, int(m.group(1)), msg, None, StringIO.StringIO()) old_error_exit(msg) upload.ErrorExit = trap_http_500 for retry in xrange(self._maxtries): try: logging.debug('%s' % request_path) - result = self.rpc_server.Send(request_path, **kwargs) - # Sometimes GAE returns a HTTP 200 but with HTTP 500 as the content. - # How nice. - return result + return self.rpc_server.Send(request_path, **kwargs) except urllib2.HTTPError, e: if retry >= (self._maxtries - 1): raise - flake_codes = [500, 502, 503] + flake_codes = {500, 502, 503} if retry_on_404: - flake_codes.append(404) + flake_codes.add(404) if e.code not in flake_codes: raise except urllib2.URLError, e: @@ -440,10 +438,10 @@ class Rietveld(object): # The reason can be a string or another exception, e.g., # socket.error or whatever else. reason_as_str = str(e.reason) - for retry_anyway in [ + for retry_anyway in ( 'Name or service not known', 'EOF occurred in violation of protocol', - 'timed out']: + 'timed out'): if retry_anyway in reason_as_str: return True return False # Assume permanent otherwise. @@ -528,6 +526,11 @@ class OAuthRpcServer(object): payload: request is a POST if not None, GET otherwise timeout: in seconds extra_headers: (dict) + + Returns: the HTTP response body as a string + + Raises: + urllib2.HTTPError """ # This method signature should match upload.py:AbstractRpcServer.Send() method = 'GET' @@ -543,7 +546,6 @@ class OAuthRpcServer(object): try: if timeout: self._http.timeout = timeout - # TODO(pgervais) implement some kind of retry mechanism (see upload.py). url = self.host + request_path if kwargs: url += "?" + urllib.urlencode(kwargs) @@ -572,6 +574,11 @@ class OAuthRpcServer(object): continue break + if ret[0].status >= 300: + raise urllib2.HTTPError( + request_path, int(ret[0]['status']), ret[1], None, + StringIO.StringIO()) + return ret[1] finally: diff --git a/tests/rietveld_test.py b/tests/rietveld_test.py index 7bcb9bcbc..880d17e70 100755 --- a/tests/rietveld_test.py +++ b/tests/rietveld_test.py @@ -5,19 +5,24 @@ """Unit tests for rietveld.py.""" +import httplib import logging import os import socket import ssl +import StringIO import sys +import tempfile import time import traceback import unittest +import urllib2 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from testing_support.patches_data import GIT, RAW from testing_support import auto_stub +from third_party import httplib2 import patch import rietveld @@ -490,6 +495,104 @@ class DefaultTimeoutTest(auto_stub.TestCase): self.rietveld.post('/api/1234', [('key', 'data')]) self.assertNotEqual(self.sleep_time, 0) + +class OAuthRpcServerTest(auto_stub.TestCase): + def setUp(self): + super(OAuthRpcServerTest, self).setUp() + self.rpc_server = rietveld.OAuthRpcServer( + 'http://www.example.com', 'foo', 'bar') + + def set_mock_response(self, status): + def mock_http_request(*args, **kwargs): + return (httplib2.Response({'status': status}), 'body') + self.mock(self.rpc_server._http, 'request', mock_http_request) + + def test_404(self): + self.set_mock_response(404) + with self.assertRaises(urllib2.HTTPError) as ctx: + self.rpc_server.Send('/foo') + self.assertEquals(404, ctx.exception.code) + + def test_200(self): + self.set_mock_response(200) + ret = self.rpc_server.Send('/foo') + self.assertEquals('body', ret) + + +class RietveldOAuthRpcServerTest(auto_stub.TestCase): + def setUp(self): + super(RietveldOAuthRpcServerTest, self).setUp() + with tempfile.NamedTemporaryFile() as private_key_file: + self.rietveld = rietveld.JwtOAuth2Rietveld( + 'http://www.example.com', 'foo', private_key_file.name, maxtries=2) + + self.mock(time, 'sleep', lambda duration: None) + + def test_retries_500(self): + urls = [] + def mock_http_request(url, *args, **kwargs): + urls.append(url) + return (httplib2.Response({'status': 500}), 'body') + self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) + + with self.assertRaises(urllib2.HTTPError) as ctx: + self.rietveld.get('/foo') + self.assertEquals(500, ctx.exception.code) + + self.assertEqual(2, len(urls)) # maxtries was 2 + self.assertEqual(['https://www.example.com/foo'] * 2, urls) + + def test_does_not_retry_404(self): + urls = [] + def mock_http_request(url, *args, **kwargs): + urls.append(url) + return (httplib2.Response({'status': 404}), 'body') + self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) + + with self.assertRaises(urllib2.HTTPError) as ctx: + self.rietveld.get('/foo') + self.assertEquals(404, ctx.exception.code) + + self.assertEqual(1, len(urls)) # doesn't retry + + def test_retries_404_when_requested(self): + urls = [] + def mock_http_request(url, *args, **kwargs): + urls.append(url) + return (httplib2.Response({'status': 404}), 'body') + self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) + + with self.assertRaises(urllib2.HTTPError) as ctx: + self.rietveld.get('/foo', retry_on_404=True) + self.assertEquals(404, ctx.exception.code) + + self.assertEqual(2, len(urls)) # maxtries was 2 + + def test_socket_timeout(self): + urls = [] + def mock_http_request(url, *args, **kwargs): + urls.append(url) + raise socket.error('timed out') + self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) + + with self.assertRaises(socket.error): + self.rietveld.get('/foo') + + self.assertEqual(2, len(urls)) # maxtries was 2 + + def test_other_socket_error(self): + urls = [] + def mock_http_request(url, *args, **kwargs): + urls.append(url) + raise socket.error('something else') + self.mock(self.rietveld.rpc_server._http, 'request', mock_http_request) + + with self.assertRaises(socket.error): + self.rietveld.get('/foo') + + self.assertEqual(1, len(urls)) + + if __name__ == '__main__': logging.basicConfig(level=[ logging.ERROR, logging.INFO, logging.DEBUG][min(2, sys.argv.count('-v'))])