From 113b6826a38a1054d355b63d4300817b604b8aae Mon Sep 17 00:00:00 2001 From: Steven Arcangeli Date: Thu, 18 Jan 2024 18:03:45 -0800 Subject: [PATCH] fix(pull): pull all branches in stack --- gitstack.py | 137 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 82 insertions(+), 55 deletions(-) diff --git a/gitstack.py b/gitstack.py index d1d35b7..3ca9c28 100755 --- a/gitstack.py +++ b/gitstack.py @@ -190,10 +190,13 @@ def rev_parse(ref: str) -> Optional[str]: return None @staticmethod - def list_branches() -> Dict[str, str]: + def list_branches(all: bool = False) -> Dict[str, str]: """Mapping of branch name to ref""" ret = {} - for b in git_lines("branch", "--format=%(refname:short) %(objectname)"): + args = ["branch", "--format=%(refname:short) %(objectname)"] + if all: + args.append("-a") + for b in git_lines(*args): if b.startswith("(HEAD detached at"): continue name, ref = b.split() @@ -985,46 +988,62 @@ def load_prs(self) -> None: @classmethod def load(cls) -> "Repo": - diffs: Dict[str, Diff] = {} branches = git.list_branches() - tips = {v: k for k, v in branches.items()} - for branch in branches: - if branch != git.get_main_branch(): - diffs[branch] = Diff.from_branch(branch, tips) - stacks = [] - diffs_by_prev: Dict[str, Diff] = {} - diff: Optional[Diff] = None - for diff in diffs.values(): - if diff.label.prev_branch is not None: - diffs_by_prev[diff.label.prev_branch] = diff - - while diffs: - diff = None - # Find a diff with no previous branch - for d in diffs.values(): - if d.label.prev_branch not in diffs or d.label.prev_branch is None: - diff = d - break - if diff is None: - diff = diffs[next(iter(diffs))] + stacks = create_stacks(branches) + return cls(stacks) + + +def strip_remote(branch: str) -> str: + if branch.startswith("origin/"): + return branch[7:] + else: + return branch + + +def create_stacks(branches: Dict[str, str]) -> List[Stack]: + diffs: Dict[str, Diff] = {} + tips = {v: k for k, v in branches.items()} + # We have to strip the remote prefix off of the keys in diffs and diffs_by_prev + # because the commit labels themselves do not have the remote prefix. Regardless of + # the real branch name, we have to always look it up using the representation from + # the commit labels, which is local-only. + for branch in branches: + if strip_remote(branch) != git.get_main_branch(): + diffs[strip_remote(branch)] = Diff.from_branch(branch, tips) + stacks = [] + diffs_by_prev: Dict[str, Diff] = {} + diff: Optional[Diff] = None + for diff in diffs.values(): + if diff.label.prev_branch is not None: + diffs_by_prev[diff.label.prev_branch] = diff + + while diffs: + diff = None + # Find a diff with no previous branch + for d in diffs.values(): + if d.label.prev_branch not in diffs or d.label.prev_branch is None: + diff = d + break + if diff is None: + diff = diffs[next(iter(diffs))] + assert diff.branch is not None + del diffs[strip_remote(diff.branch.name)] + + diff_list = [diff] + while strip_remote(diff.branch.name) in diffs_by_prev: + next_diff = diffs_by_prev.pop(strip_remote(diff.branch.name)) + assert next_diff.branch is not None + if strip_remote(next_diff.branch.name) in diffs: + del diffs[strip_remote(next_diff.branch.name)] + diff_list.append(next_diff) + else: + # There is probably a cycle somewhere + break + diff = next_diff assert diff.branch is not None - del diffs[diff.branch.name] - - diff_list = [diff] - while diff.branch.name in diffs_by_prev: - next_diff = diffs_by_prev.pop(diff.branch.name) - assert next_diff.branch is not None - if next_diff.branch.name in diffs: - del diffs[next_diff.branch.name] - diff_list.append(next_diff) - else: - # There is probably a cycle somewhere - break - diff = next_diff - assert diff.branch is not None - stacks.append(Stack(diff_list)) - return cls(stacks) + stacks.append(Stack(diff_list)) + return stacks @dataclass @@ -1682,27 +1701,35 @@ def run(self, branch: Optional[str] = None, force: bool = False) -> None: target = branch git.fetch() - all_branches = git.list_branches() + all_branches = git.list_branches(all=True) + remote_branches = { + k: v for k, v in all_branches.items() if k.startswith("origin/") + } + stacks = create_stacks(remote_branches) + stack = None + for candidate in stacks: + branches = [b.name for b in candidate.branches()] + if "origin/" + target in branches: + stack = candidate + break + if stack is None: + print_err("No stack found for branch", target) + sys.exit(1) + + for remote_branch in stack.branches(): + assert remote_branch.name.startswith("origin/") + local_branch = remote_branch.name[7:] - while branch: - if branch not in all_branches: - git.create_branch(branch, "HEAD") + if local_branch not in all_branches: + git.create_branch(local_branch, "HEAD") else: - git.switch_branch(branch) + git.switch_branch(local_branch) if force: - git.reset("origin/" + branch, hard=True) + git.reset(remote_branch.name, hard=True) else: - git.fast_forward("origin/" + branch) - - branch = None - refs = git.refs_between(git.merge_base("HEAD"), "HEAD") - refs.reverse() - for ref in refs: - label = Diff.parse_labels(ref) - if label.prev_branch: - branch = label.prev_branch - break + git.fast_forward(remote_branch.name) + git.switch_branch(target)