#!/usr/bin/env python3
# Copyright 2024 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
'''
Tool to squash all branches and their downstream branches. Useful to avoid
potential conflicts during a git rebase-update with multiple stacked CLs.
'''

import argparse
import collections
import git_common as git
import sys


# Returns the list of branches in `tree` (a {branch: upstream_branch} dict)
# that have diverged from their respective upstream branch.
def get_diverged_branches(tree):
    diverged_branches = []
    for branch, upstream_branch in tree.items():
        # If the merge base of a branch and its upstream is not equal to the
        # upstream itself, then the two branches have diverged.
        upstream_branch_hash = git.hash_one(upstream_branch)
        merge_base_hash = git.hash_one(git.get_or_create_merge_base(branch))
        if upstream_branch_hash != merge_base_hash:
            diverged_branches.append(branch)
    return diverged_branches


# Returns a dictionary that maps every branch of `tree`, and every upstream
# branch referenced by it, to its hash before the squashing started.
def get_initial_hashes(tree):
    # An upstream with several downstream branches appears once per downstream
    # in tree.values(); deduplicate first so each ref costs a single
    # `git.hash_one` call.
    refs = set(tree) | set(tree.values())
    return {ref: git.hash_one(ref) for ref in refs}


# Returns a dictionary that maps every upstream branch in `tree` to the list
# of its downstream branches. Branches without downstreams map to [] on lookup.
def get_downstream_branches(tree):
    downstream_branches = collections.defaultdict(list)
    for branch, upstream_branch in tree.items():
        downstream_branches[upstream_branch].append(branch)
    return downstream_branches


# Squash a branch, taking care to rebase the branch on top of the new commit
# position of its upstream branch.
#
# `initial_hashes` maps branch names to their hashes from before any squashing
# happened; it must contain both `branch` and its upstream.
def squash_branch(branch, initial_hashes):
    print('Squashing branch %s.' % branch)
    # The branch itself must not have been rewritten yet by an earlier step.
    assert initial_hashes[branch] == git.hash_one(branch)

    upstream_branch = git.upstream(branch)
    old_upstream_branch = initial_hashes[upstream_branch]

    # Because the branch's upstream has potentially changed from squashing it,
    # the current branch is rebased on top of the new upstream: the commits
    # after the upstream's old hash are replayed onto its current position.
    git.run('rebase', '--onto', upstream_branch, old_upstream_branch, branch,
            '--update-refs')

    # Now do the squashing.
    git.run('checkout', branch)
    git.squash_current_branch()


# Squashes all branches that are part of the subtree starting at `branch`,
# parents before children so each child can be rebased onto its squashed
# parent.
def squash_subtree(branch, initial_hashes, downstream_branches):
    # The upstream default never has to be squashed (e.g. origin/main).
    if branch != git.upstream_default():
        squash_branch(branch, initial_hashes)

    # Recurse on downstream branches, if any.
    for downstream_branch in downstream_branches[branch]:
        squash_subtree(downstream_branch, initial_hashes, downstream_branches)


# Entry point. Returns 0 on success and 1 when a precondition check fails.
def main(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--ignore-no-upstream',
                        action='store_true',
                        help='Allows proceeding if any branch has no '
                        'upstreams.')
    parser.add_argument('--branch',
                        '-b',
                        type=str,
                        default=git.current_branch(),
                        help='The name of the branch whose subtree must be '
                        'squashed. Defaults to the current branch.')
    opts = parser.parse_args(args)

    if git.is_dirty_git_tree('squash-branch-tree'):
        return 1

    branches_without_upstream, tree = git.get_branch_tree()

    if not opts.ignore_no_upstream and branches_without_upstream:
        print('Cannot use `git squash-branch-tree` since the following\n'
              'branches don\'t have an upstream:')
        for branch in branches_without_upstream:
            print(f'  - {branch}')
        print('Use --ignore-no-upstream to ignore this check and proceed.')
        return 1

    # squash_branch() needs the root branch and its upstream to be part of
    # the tree. Any other root (e.g. a branch without an upstream allowed by
    # --ignore-no-upstream) would otherwise fail with a KeyError traceback, so
    # reject it here, before anything is rewritten.
    if opts.branch != git.upstream_default() and opts.branch not in tree:
        print(f'Cannot use `git squash-branch-tree` on {opts.branch} since\n'
              'it does not have an upstream.')
        return 1

    diverged_branches = get_diverged_branches(tree)
    if diverged_branches:
        print('Cannot use `git squash-branch-tree` since the following\n'
              'branches have diverged from their upstream and could cause\n'
              'conflicts:')
        for diverged_branch in diverged_branches:
            print(f'  - {diverged_branch}')
        return 1

    # Before doing the squashing, save the currently checked out branch so we
    # can go back to it at the end.
    return_branch = git.current_branch()
    initial_hashes = get_initial_hashes(tree)
    downstream_branches = get_downstream_branches(tree)
    squash_subtree(opts.branch, initial_hashes, downstream_branches)
    git.run('checkout', return_branch)

    return 0


if __name__ == '__main__':  # pragma: no cover
    try:
        sys.exit(main(sys.argv[1:]))
    except KeyboardInterrupt:
        sys.stderr.write('interrupted\n')
        sys.exit(1)