diff --git a/checkout.py b/checkout.py index 09ea9c3fa..9f21acc90 100644 --- a/checkout.py +++ b/checkout.py @@ -59,7 +59,12 @@ class CheckoutBase(object): # Set to None to have verbose output. VOID = subprocess2.VOID - def __init__(self, root_dir, project_name): + def __init__(self, root_dir, project_name, post_processors): + """ + Args: + post_processor: list of lambda(checkout, patches) to call on each of the + modified files. + """ self.root_dir = root_dir self.project_name = project_name if self.project_name is None: @@ -68,6 +73,7 @@ class CheckoutBase(object): self.project_path = os.path.join(self.root_dir, self.project_name) # Only used for logging purposes. self._last_seen_revision = None + self.post_processors = None assert self.root_dir assert self.project_path @@ -82,7 +88,7 @@ class CheckoutBase(object): """ raise NotImplementedError() - def apply_patch(self, patches, post_processor=None): + def apply_patch(self, patches): """Applies a patch and returns the list of modified files. This function should throw patch.UnsupportedPatchFormat or @@ -90,8 +96,6 @@ class CheckoutBase(object): Args: patches: patch.PatchSet object. - post_processor: list of lambda(checkout, patches) to call on each of the - modified files. """ raise NotImplementedError() @@ -109,9 +113,8 @@ class RawCheckout(CheckoutBase): """Stubbed out.""" pass - def apply_patch(self, patches, post_processor=None): + def apply_patch(self, patches): """Ignores svn properties.""" - post_processor = post_processor or [] for p in patches: try: stdout = '' @@ -137,7 +140,7 @@ class RawCheckout(CheckoutBase): elif p.is_new and not os.path.exists(filepath): # There is only a header. Just create the file. open(filepath, 'w').close() - for post in post_processor: + for post in (self.post_processors or []): post(self, p) except OSError, e: raise PatchApplicationFailed(p.filename, '%s%s' % (stdout, e)) @@ -233,8 +236,9 @@ class SvnMixIn(object): class SvnCheckout(CheckoutBase, SvnMixIn): """Manages a subversion checkout.""" - def __init__(self, root_dir, project_name, commit_user, commit_pwd, svn_url): - super(SvnCheckout, self).__init__(root_dir, project_name) + def __init__(self, root_dir, project_name, commit_user, commit_pwd, svn_url, + post_processors=None): + super(SvnCheckout, self).__init__(root_dir, project_name, post_processors) self.commit_user = commit_user self.commit_pwd = commit_pwd self.svn_url = svn_url @@ -252,8 +256,7 @@ class SvnCheckout(CheckoutBase, SvnMixIn): self._last_seen_revision = revision return revision - def apply_patch(self, patches, post_processor=None): - post_processor = post_processor or [] + def apply_patch(self, patches): for p in patches: try: # It is important to use credentials=False otherwise credentials could @@ -306,7 +309,7 @@ class SvnCheckout(CheckoutBase, SvnMixIn): params = value.split('=', 1) stdout += self._check_output_svn( ['propset'] + params + [p.filename], credentials=False) - for post in post_processor: + for post in (self.post_processors or []): post(self, p) except OSError, e: raise PatchApplicationFailed(p.filename, '%s%s' % (stdout, e)) @@ -369,8 +372,10 @@ class SvnCheckout(CheckoutBase, SvnMixIn): class GitCheckoutBase(CheckoutBase): """Base class for git checkout. Not to be used as-is.""" - def __init__(self, root_dir, project_name, remote_branch): - super(GitCheckoutBase, self).__init__(root_dir, project_name) + def __init__(self, root_dir, project_name, remote_branch, + post_processors=None): + super(GitCheckoutBase, self).__init__( + root_dir, project_name, post_processors) # There is no reason to not hardcode it. self.remote = 'origin' self.remote_branch = remote_branch @@ -391,14 +396,13 @@ class GitCheckoutBase(CheckoutBase): if self.working_branch in branches: self._call_git(['branch', '-D', self.working_branch]) - def apply_patch(self, patches, post_processor=None): + def apply_patch(self, patches): """Applies a patch on 'working_branch' and switch to it. Also commits the changes on the local branch. Ignores svn properties and raise an exception on unexpected ones. """ - post_processor = post_processor or [] # It this throws, the checkout is corrupted. Maybe worth deleting it and # trying again? if self.remote_branch: @@ -435,7 +439,7 @@ class GitCheckoutBase(CheckoutBase): p.filename, 'Cannot apply svn property %s to file %s.' % ( prop[0], p.filename)) - for post in post_processor: + for post in (self.post_processors or []): post(self, p) except OSError, e: raise PatchApplicationFailed(p.filename, '%s%s' % (stdout, e)) @@ -492,10 +496,10 @@ class GitSvnCheckoutBase(GitCheckoutBase, SvnMixIn): def __init__(self, root_dir, project_name, remote_branch, commit_user, commit_pwd, - svn_url, trunk): + svn_url, trunk, post_processors=None): """trunk is optional.""" super(GitSvnCheckoutBase, self).__init__( - root_dir, project_name + '.git', remote_branch) + root_dir, project_name + '.git', remote_branch, post_processors) self.commit_user = commit_user self.commit_pwd = commit_pwd # svn_url in this case is the root of the svn repository. @@ -583,11 +587,11 @@ class GitSvnPremadeCheckout(GitSvnCheckoutBase): def __init__(self, root_dir, project_name, remote_branch, commit_user, commit_pwd, - svn_url, trunk, git_url): + svn_url, trunk, git_url, post_processors=None): super(GitSvnPremadeCheckout, self).__init__( root_dir, project_name, remote_branch, commit_user, commit_pwd, - svn_url, trunk) + svn_url, trunk, post_processors) self.git_url = git_url assert self.git_url @@ -627,11 +631,11 @@ class GitSvnCheckout(GitSvnCheckoutBase): def __init__(self, root_dir, project_name, commit_user, commit_pwd, - svn_url, trunk): + svn_url, trunk, post_processors=None): super(GitSvnCheckout, self).__init__( root_dir, project_name, 'trunk', commit_user, commit_pwd, - svn_url, trunk) + svn_url, trunk, post_processors) def prepare(self): """Creates the initial checkout for the repo.""" @@ -663,8 +667,8 @@ class ReadOnlyCheckout(object): def get_settings(self, key): return self.checkout.get_settings(key) - def apply_patch(self, patches, post_processor=None): - return self.checkout.apply_patch(patches, post_processor) + def apply_patch(self, patches): + return self.checkout.apply_patch(patches) def commit(self, message, user): # pylint: disable=R0201 logging.info('Would have committed for %s with message: %s' % ( diff --git a/tests/checkout_test.py b/tests/checkout_test.py index 59a3f529a..4ea73a06e 100755 --- a/tests/checkout_test.py +++ b/tests/checkout_test.py @@ -236,10 +236,11 @@ class BaseTest(fake_repos.FakeReposTestBase): def _test_process(self, co): """Makes sure the process lambda is called correctly.""" + co.post_processors = [lambda *args: results.append(args)] co.prepare() ps = self.get_patches() results = [] - co.apply_patch(ps, [lambda *args: results.append(args)]) + co.apply_patch(ps) expected = [(co, p) for p in ps.patches] self.assertEquals(expected, results) @@ -479,7 +480,7 @@ class RawCheckout(SvnBaseTest): self.base_co.prepare() def _get_co(self, read_only): - co = checkout.RawCheckout(self.root_dir, self.name) + co = checkout.RawCheckout(self.root_dir, self.name, None) if read_only: return checkout.ReadOnlyCheckout(co) return co