diff --git a/release-management/scripts/git-local-merge.py b/release-management/scripts/git-local-merge.py index 3333f36..7d6bdeb 100644 --- a/release-management/scripts/git-local-merge.py +++ b/release-management/scripts/git-local-merge.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 from __future__ import annotations +if __name__ != "__main__": + raise ImportError(f"{__name__} should not be used as a module.") + import argparse import json +import os import shutil +import signal import subprocess import sys from typing import NoReturn @@ -38,10 +43,25 @@ Merge pull request #{id} from {branch} def main() -> NoReturn: - parser = argparse.ArgumentParser(description="Locally merge multiple GitHub PRs.") - parser.add_argument("ids", nargs="+", help="PR ids to merge.") + parser = argparse.ArgumentParser(prog="git-local-merge", description="Locally merge multiple GitHub PRs.") + parser.add_argument("ids", nargs="*", help="PR ids to merge.", type=int) + parser.add_argument( + "-f", + "--file", + help="Path to a file containing PR ids.", + type=argparse.FileType(encoding="utf-8"), + ) args = parser.parse_args() + ids: set[int] = set(args.ids) + if args.file: + try: + ids.update(int(id) for id in args.file.read().split()) + except ValueError: + parser.error("File contained invalid int values.") + if not ids: + parser.error("No ids provided.") + if subprocess.run(["git", "checkout"], stdout=subprocess.PIPE).returncode != 0: sys.exit(1) @@ -61,14 +81,11 @@ def main() -> NoReturn: out = subprocess.run(["gh", "repo", "set-default", "--view"], capture_output=True) if out.stderr: subprocess.run(["gh", "repo", "set-default"]) - out = subprocess.run( - ["gh", "repo", "set-default", "--view"], capture_output=True - ) + out = subprocess.run(["gh", "repo", "set-default", "--view"], capture_output=True) if out.stderr: print("Failed to setup default remote repository!", file=sys.stderr) sys.exit(1) - ids = set([int(id) for id in args.ids]) failed: set[int] = set() prs: list[PullRequestInfo] = [] @@ -81,9 +98,7 @@ def main() -> NoReturn: failed.add(id) for pr in prs: - subprocess.run( - ["gh", "pr", "checkout", str(pr.id), "--branch", pr.branch, "--force"] - ) + subprocess.run(["gh", "pr", "checkout", str(pr.id), "--branch", pr.branch, "--force"]) subprocess.run(["git", "checkout", BASE_BRANCH]) for pr in prs: @@ -98,5 +113,8 @@ def main() -> NoReturn: sys.exit(len(failed)) -if __name__ == "__main__": +try: main() +except KeyboardInterrupt: + signal.signal(signal.SIGINT, signal.SIG_DFL) + os.kill(os.getpid(), signal.SIGINT)