# SPDX-License-Identifier: GPL-2.0+
#
# Copyright 2025 Simon Glass <sjg@chromium.org>
#
"""Helper functions for handling the 'series' subcommand
"""

import asyncio
from collections import OrderedDict, defaultdict, namedtuple
from datetime import datetime
import hashlib
import os
import re
import sys
import time
from types import SimpleNamespace

import aiohttp
import pygit2
from pygit2.enums import CheckoutStrategy

from u_boot_pylib import gitutil
from u_boot_pylib import terminal
from u_boot_pylib import tout

from patman import patchstream
from patman.database import Database, Pcommit, SerVer
from patman import patchwork
from patman.series import Series
from patman import status


# Tag to use for Change IDs
CHANGE_ID_TAG = 'Change-Id'

# Length of hash to display
HASH_LEN = 10

# Shorter version of some states, to save horizontal space
SHORTEN_STATE = {
    'handled-elsewhere': 'elsewhere',
    'awaiting-upstream': 'awaiting',
    'not-applicable': 'n/a',
    'changes-requested': 'changes',
}

# Summary info returned from Cseries.link_auto_all()
AUTOLINK = namedtuple('autolink', 'name,version,link,desc,result')


def oid(oid_val):
    """Convert a hash string into a shortened hash

    The number of hex digits git uses for showing hashes depends on the size of
    the repo. For the purposes of showing hashes to the user in lists, we use a
    fixed value for now

    Args:
        oid_val (str or Pygit2.oid): Hash value to shorten

    Return:
        str: Shortened hash (the first HASH_LEN characters)
    """
    return str(oid_val)[:HASH_LEN]


def split_name_version(in_name):
    """Split a branch name into its series name and its version

    For example:
        'series' returns ('series', None)
        'series3' returns ('series', 3)

    Only the leading run of non-digits and the digits which immediately follow
    it are used; anything after that is ignored.

    Args:
        in_name (str): Name to parse

    Return:
        tuple:
            str: series name
            int: series version, or None if there is none in in_name
    """
    # Both groups are optional, so this pattern matches any string
    m_ver = re.match(r'([^0-9]*)(\d*)', in_name)
    name = m_ver.group(1)
    digits = m_ver.group(2)
    version = int(digits) if digits else None
    return name, version


class CseriesHelper:
    """Helper functions for Cseries

    This class handles database read/write as well as operations in a git
    directory to update series information.
    """
    def __init__(self, topdir=None, colour=terminal.COLOR_IF_TERMINAL):
        """Set up a new CseriesHelper

        Args:
            topdir (str): Top-level directory of the repo
            colour (terminal.enum): Whether to enable ANSI colour or not

        Properties:
            gitdir (str): Git directory (typically topdir + '/.git')
            db (Database): Database handler
            col (terminal.Colour): Colour object
            _fake_time (float): Holds the current fake time for tests, in
                seconds
            _fake_sleep (func): Function provided by a test; called to fake a
                'time.sleep()' call and take whatever action it wants to take.
                The only argument is the (Float) time to sleep for; it returns
                nothing
            fake_now (datetime or None): Value returned by get_now() when set
            loop (asyncio event loop): Loop used for Patchwork operations
        """
        self.topdir = topdir
        self.gitdir = None
        self.db = None
        self.col = terminal.Color(colour)
        self._fake_time = None
        self._fake_sleep = None
        self.fake_now = None
        try:
            self.loop = asyncio.get_event_loop()
        except RuntimeError:
            # There is no current event loop (always the case on newer Python
            # versions unless one was set explicitly), so create one and make
            # it current, as older versions did implicitly
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

    def open_database(self):
        """Open the database ready for use

        Sets up topdir (if not already provided), gitdir and db

        Raises:
            ValueError: No git repo could be detected
        """
        if not self.topdir:
            self.topdir = gitutil.get_top_level()
            if not self.topdir:
                raise ValueError('No git repo detected in current directory')
        self.gitdir = os.path.join(self.topdir, '.git')
        fname = f'{self.topdir}/.patman.db'

        # For the first instance, start it up with the expected schema
        self.db, is_new = Database.get_instance(fname)
        if is_new:
            self.db.start()
        else:
            # If a previous test has already checked the schema, just open it
            self.db.open_it()

    def close_database(self):
        """Close the database, if it is open"""
        if self.db:
            self.db.close()

    def commit(self):
        """Commit changes to the database"""
        self.db.commit()

    def rollback(self):
        """Roll back changes to the database"""
        self.db.rollback()

    def set_fake_time(self, fake_sleep):
        """Setup the fake timer

        The fake time starts at 0

        Args:
            fake_sleep (func(float)): Function to call to fake a sleep
        """
        self._fake_time = 0
        self._fake_sleep = fake_sleep

    def inc_fake_time(self, inc_s):
        """Increment the fake time

        Args:
            inc_s (float): Amount to increment the fake time by
        """
        self._fake_time += inc_s

    def get_time(self):
        """Get the current time, fake or real

        This function should always be used to read the time so that faking the
        time works correctly in tests.

        Return:
            float: Fake time, if time is being faked, else real time
        """
        if self._fake_time is not None:
            return self._fake_time
        return time.monotonic()

    def sleep(self, time_s):
        """Sleep for a while

        This function should always be used to sleep so that faking the time
        works correctly in tests.

        Args:
            time_s (float): Amount of seconds to sleep for
        """
        print(f'Sleeping for {time_s} seconds')
        if self._fake_time is not None:
            self._fake_sleep(time_s)
        else:
            time.sleep(time_s)

    def get_now(self):
        """Get the time now

        This function should always be used to read the datetime, so that
        faking the time works correctly in tests

        Return:
            DateTime object
        """
        if self.fake_now:
            return self.fake_now
        return datetime.now()

    def get_ser_ver_list(self):
        """Get a list of patchwork entries from the database

        Return:
            list of SER_VER
        """
        return self.db.ser_ver_get_list()

    def get_ser_ver_dict(self):
        """Get a dict of patchwork entries from the database

        Return: dict contain all records:
            key (int): ser_ver id
            value (SER_VER): Information about one ser_ver record
        """
        return {sver.idnum: sver for sver in self.get_ser_ver_list()}

    def get_upstream_dict(self):
        """Get a list of upstream entries from the database

        Return:
            OrderedDict:
                key (str): upstream name
                value (str): url
        """
        return self.db.upstream_get_dict()

    def get_pcommit_dict(self, find_svid=None):
        """Get a dict of pcommits entries from the database

        Args:
            find_svid (int): If not None, finds the records associated with a
                particular series and version

        Return:
            OrderedDict:
                key (int): record ID if find_svid is None, else seq
                value (PCOMMIT): record data
        """
        pcdict = OrderedDict()
        for rec in self.db.pcommit_get_list(find_svid):
            if find_svid is not None:
                pcdict[rec.seq] = rec
            else:
                pcdict[rec.idnum] = rec
        return pcdict

    def _get_series_info(self, idnum):
        """Get information for a series from the database

        Args:
            idnum (int): Series ID to look up

        Return: tuple:
            str: Series name
            str: Series description

        Raises:
            ValueError: Series is not found
        """
        return self.db.series_get_info(idnum)

    def prep_series(self, name, end=None):
        """Prepare to work with a series

        Args:
            name (str): Branch name with version appended, e.g. 'fix2'; if
                empty, the name is worked out from the current branch
            end (str or None): Commit to end at, e.g. 'my_branch~16'. Only
                commits up to that are processed. None to process commits up to
                the upstream branch

        Return: tuple:
            str: Branch name, e.g. 'fix2'
            Series: Collected series information, including name
            int: Version number, e.g. 2
            str: Message to show, or None if there is none

        Raises:
            ValueError: The branch does not exist, or no commits were found
        """
        ser, version = self._parse_series_and_version(name, None)
        if not name:
            name = self._get_branch_name(ser.name, version)

        # First check we have a branch with this name
        if not gitutil.check_branch(name, git_dir=self.gitdir):
            raise ValueError(f"No branch named '{name}'")

        count = gitutil.count_commits_to_branch(name, self.gitdir, end)
        if not count:
            raise ValueError('Cannot detect branch automatically: '
                             'Perhaps use -U <upstream-commit> ?')

        series = patchstream.get_metadata(name, 0, count, git_dir=self.gitdir)
        self._copy_db_fields_to(series, ser)
        msg = None
        if end:
            repo = pygit2.init_repository(self.gitdir)
            target = repo.revparse_single(end)
            first_line = target.message.splitlines()[0]
            msg = f'Ending before {oid(target.id)} {first_line}'

        return name, series, version, msg

    def _copy_db_fields_to(self, series, in_series):
        """Copy over fields used by Cseries from one series to another

        This copies desc, idnum and name

        Args:
            series (Series): Series to copy to
            in_series (Series): Series to copy from
        """
        series.desc = in_series.desc
        series.idnum = in_series.idnum
        series.name = in_series.name

    def _handle_mark(self, branch_name, in_series, version, mark,
                     allow_unmarked, force_version, dry_run):
        """Handle marking a series, checking for unmarked commits, etc.

        Args:
            branch_name (str): Name of branch to sync, or None for current one
            in_series (Series): Series object
            version (int): branch version, e.g. 2 for 'mychange2'
            mark (bool): True to mark each commit with a change ID
            allow_unmarked (bool): True to not require each commit to be
                marked
            force_version (bool): True if ignore a Series-version tag that
                doesn't match its branch name
            dry_run (bool): True to do a dry run

        Returns:
            Series: New series object, if the series was marked;
                copy_db_fields_to() is used to copy fields over

        Raises:
            ValueError: Series being unmarked when it should be marked, etc.
        """
        series = in_series
        if 'version' in series and int(series.version) != version:
            msg = (f"Series name '{branch_name}' suggests version {version} "
                   f"but Series-version tag indicates {series.version}")
            if not force_version:
                raise ValueError(msg + ' (see --force-version)')

            tout.warning(msg)
            tout.warning(f'Updating Series-version tag to version {version}')
            self.update_series(branch_name, series, int(series.version),
                               new_name=None, dry_run=dry_run,
                               add_vers=version)

            # Collect the commits again, as the hashes have changed
            series = patchstream.get_metadata(branch_name, 0,
                                              len(series.commits),
                                              git_dir=self.gitdir)
            self._copy_db_fields_to(series, in_series)

        if mark:
            add_oid = self._mark_series(branch_name, series, dry_run=dry_run)

            # Collect the commits again, as the hashes have changed
            series = patchstream.get_metadata(add_oid, 0, len(series.commits),
                                              git_dir=self.gitdir)
            self._copy_db_fields_to(series, in_series)

        bad_count = sum(1 for commit in series.commits
                        if not commit.change_id)
        if bad_count and not allow_unmarked:
            raise ValueError(
                f'{bad_count} commit(s) are unmarked; please use -m or -M')

        return series

    def _add_series_commits(self, series, svid):
        """Add the commits from a series into the database

        Args:
            series (Series): Series containing commits to add
            svid (int): ser_ver-table ID to use for each commit
        """
        to_add = [Pcommit(None, seq, commit.subject, None, commit.change_id,
                          None, None, None)
                  for seq, commit in enumerate(series.commits)]

        self.db.pcommit_add_list(svid, to_add)

    def get_series_by_name(self, name, include_archived=False):
        """Get a Series object from the database by name

        Args:
            name (str): Name of series to get
            include_archived (bool): True to search in archives series

        Return:
            Series: Object containing series info, or None if none
        """
        idnum = self.db.series_find_by_name(name, include_archived)
        if not idnum:
            return None
        name, desc = self.db.series_get_info(idnum)

        return Series.from_fields(idnum, name, desc)

    def _get_branch_name(self, name, version):
        """Get the branch name for a particular version

        Args:
            name (str): Base name of branch
            version (int): Version number to use

        Return:
            str: Branch name; the version is appended only if it is above 1
        """
        return name + (f'{version}' if version > 1 else '')

    def _ensure_version(self, ser, version):
        """Ensure that a version exists in a series

        Args:
            ser (Series): Series information, with idnum and name used here
            version (int): Version to check

        Returns:
            list of int: List of versions

        Raises:
            ValueError: The series does not have that version
        """
        versions = self._get_version_list(ser.idnum)
        if version not in versions:
            raise ValueError(
                f"Series '{ser.name}' does not have a version {version}")
        return versions

    def _set_link(self, ser_id, name, version, link, update_commit,
                  dry_run=False):
        """Add / update a series-links link for a series

        Args:
            ser_id (int): Series ID number
            name (str): Series name (used to find the branch)
            version (int): Version number (used to update the database)
            link (str): Patchwork link-string for the series
            update_commit (bool): True to update the current commit with the
                link
            dry_run (bool): True to do a dry run

        Return:
            int: 1 if the database was updated, 0 if not (the original
                documentation gives the reason as the ser_id or version not
                being found)
        """
        if update_commit:
            branch_name = self._get_branch_name(name, version)
            _, ser, max_vers, _ = self.prep_series(branch_name)
            self.update_series(branch_name, ser, max_vers, add_vers=version,
                               dry_run=dry_run, add_link=link)
        if link is None:
            link = ''
        updated = 1 if self.db.ser_ver_set_link(ser_id, version, link) else 0
        if dry_run:
            self.rollback()
        else:
            self.commit()

        return updated

    def _get_autolink_dict(self, sdict, link_all_versions):
        """Get a dict of ser_vers to fetch, along with their patchwork links

        Note that this returns items that already have links, as well as those
        without links

        Args:
            sdict:
                key: series ID
                value: Series with idnum, name and desc filled out
            link_all_versions (bool): True to sync all versions of a series,
                False to sync only the latest version

        Return:
            dict:
                key (int): svid
                value (tuple):
                    int: series ID
                    str: series name
                    int: series version
                    str: patchwork link for the series, or None if none
                    Series: series read from the branch, with the idnum, name
                        and desc copied from the database
        """
        svdict = self.get_ser_ver_dict()

        if link_all_versions:
            entries = [(svinfo.idnum, svinfo.series_id, svinfo.version)
                       for svinfo in svdict.values()]
        else:
            # Find the maximum version for each series
            entries = self._series_all_max_versions()

        # Get a list of links to fetch
        to_fetch = {}
        for svid, ser_id, version in entries:
            svinfo = svdict[svid]
            ser = sdict[ser_id]

            # The number of commits in the branch comes from the database
            count = len(self.get_pcommit_dict(svid))
            branch = self._join_name_version(ser.name, version)
            series = patchstream.get_metadata(branch, 0, count,
                                              git_dir=self.gitdir)
            self._copy_db_fields_to(series, ser)

            to_fetch[svid] = (ser_id, series.name, version, svinfo.link,
                              series)
        return to_fetch

    def _get_version_list(self, idnum):
        """Get a list of the versions available for a series

        Args:
            idnum (int): ID of series to look up

        Return:
            list of int: List of versions

        Raises:
            ValueError: idnum is None
        """
        if idnum is None:
            raise ValueError('Unknown series idnum')
        return self.db.series_get_version_list(idnum)

    def _join_name_version(self, in_name, version):
        """Convert a series name plus a version into a branch name

        For example:
            ('series', 1) returns 'series'
            ('series', 3) returns 'series3'

        Args:
            in_name (str): Series name
            version (int): Version number

        Return:
            str: associated branch name
        """
        if version == 1:
            return in_name
        return f'{in_name}{version}'

    def _parse_series(self, name, include_archived=False):
        """Parse the name of a series, or detect it from the current branch

        Args:
            name (str or None): name of series
            include_archived (bool): True to search in archives series

        Return:
            Series: New object with the name set; idnum is also set if the
                series exists in the database
        """
        if not name:
            name = gitutil.get_branch(self.gitdir)
        name, _ = split_name_version(name)
        ser = self.get_series_by_name(name, include_archived)
        if not ser:
            ser = Series()
            ser.name = name
        return ser

    def _parse_series_and_version(self, in_name, in_version):
        """Parse name and version of a series, or detect from current branch

        Figures out the name from in_name, or if that is None, from the current
        branch.

        Uses the version in_version, or if that is None, uses the int at the
        end of the name (e.g. 'series' is version 1, 'series4' is version 4)

        Args:
            in_name (str or None): name of series
            in_version (str or None): version of series

        Return:
            tuple:
                Series: New object with the name set; idnum is also set if the
                    series exists in the database
                int: Series version-number detected from the name
                    (e.g. 'fred' is version 1, 'fred2' is version 2)

        Raises:
            ValueError: No branch could be detected, the name consists only of
                a number, or the version exceeds 99
        """
        name = in_name
        if not name:
            name = gitutil.get_branch(self.gitdir)
            if not name:
                raise ValueError('No branch detected: please use -s <series>')
        name, version = split_name_version(name)
        if not name:
            raise ValueError(f"Series name '{in_name}' cannot be a number, "
                             f"use '<name><version>'")
        if in_version:
            if version and version != in_version:
                tout.warning(
                    f"Version mismatch: -V has {in_version} but branch name "
                    f'indicates {version}')
            version = in_version
        if not version:
            version = 1
        if version > 99:
            raise ValueError(f"Version {version} exceeds 99")
        ser = self.get_series_by_name(name)
        if not ser:
            ser = Series()
            ser.name = name
        return ser, version

    def _series_get_version_stats(self, idnum, vers):
        """Get the stats for a series

        Args:
            idnum (int): ID number of series to process
            vers (int): Version number to process

        Return:
            tuple:
                str: Status string, '<accepted>/<count>'; <accepted> is '-' if
                    the series version has no link
                OrderedDict:
                    key (int): seq
                    value (PCOMMIT): record data
        """
        svid, link = self._get_series_svid_link(idnum, vers)
        pwc = self.get_pcommit_dict(svid)
        count = len(pwc)
        if link:
            accepted = sum(1 for pcm in pwc.values()
                           if pcm.state == 'accepted')
        else:
            accepted = '-'
        return f'{accepted}/{count}', pwc

    def get_series_svid(self, series_id, version):
        """Get the ser_ver record ID of a series version

        Args:
            series_id (int): id of the series to look up
            version (int): version number to look up

        Return:
            int: record ID (svid) of the matching ser_ver record

        Raises:
            ValueError: No matching series found
        """
        return self._get_series_svid_link(series_id, version)[0]

    def _get_series_svid_link(self, series_id, version):
        """Get the record ID and patchwork link of a series version

        Args:
            series_id (int): series ID to look up
            version (int): version number to look up

        Return:
            tuple:
                int: record id
                str: link
        """
        recs = self.get_ser_ver(series_id, version)
        return recs.idnum, recs.link

    def get_ser_ver(self, series_id, version):
        """Get the patchwork details for a series version

        Args:
            series_id (int): series ID to look up
            version (int): version number to look up

        Return:
            SER_VER: Requested information

        Raises:
            ValueError: There is no matching idnum/version
        """
        return self.db.ser_ver_get_for_series(series_id, version)

    def _prepare_process(self, name, count, new_name=None, quiet=False):
        """Get ready to process all commits in a branch

        This points HEAD at the upstream commit (detached), ready for commits
        to be added on top of it.

        Args:
            name (str): Name of the branch to process
            count (int): Number of commits
            new_name (str or None): New name, if a new branch is to be created
            quiet (bool): True to avoid output (used for testing)

        Return: tuple:
            pygit2.repo: Repo to use
            pygit2.Reference: HEAD, after it was set to the upstream commit
            pygit2.Branch: Original branch, for later use
            str: (Possibly new) name of branch to process
            pygit2.Commit: Upstream commit, onto which commits should be added
            list of pygit2.Commit: commits to process, in order
            pygit2.Reference: Original head before processing started, or None
                if HEAD was already at branch `name`

        Raises:
            ValueError: There are modified files in the tree
        """
        upstream_guess = gitutil.get_upstream(self.gitdir, name)[0]

        tout.debug(f"_process_series name '{name}' new_name '{new_name}' "
                   f"upstream_guess '{upstream_guess}'")
        dirty = gitutil.check_dirty(self.gitdir, self.topdir)
        if dirty:
            raise ValueError(
                f"Modified files exist: use 'git status' to check: "
                f'{dirty[:5]}')
        repo = pygit2.init_repository(self.gitdir)

        commit = None
        upstream_name = None
        if upstream_guess:
            try:
                upstream = repo.lookup_reference(upstream_guess)
                upstream_name = upstream.name
                commit = upstream.peel(pygit2.enums.ObjectType.COMMIT)
            except KeyError:
                pass
            except pygit2.repository.InvalidSpecError as exc:
                print(f"Error '{exc}'")
        if not upstream_name:
            # Fall back to counting back from the end of the branch
            upstream_name = f'{name}~{count}'
            commit = repo.revparse_single(upstream_name)

        branch = repo.lookup_branch(name)
        if not quiet:
            tout.info(
                f'Checking out upstream commit {upstream_name}: '
                f'{oid(commit.oid)}')

        # Only remember the old head if it differs from the branch being
        # processed
        old_head = repo.head
        if old_head.shorthand == name:
            old_head = None

        if new_name:
            name = new_name
        repo.set_head(commit.oid)

        # Walk back from the end of the branch, following first parents
        commits = []
        cmt = repo.get(branch.target)
        for _ in range(count):
            commits.append(cmt)
            cmt = cmt.parents[0]

        return (repo, repo.head, branch, name, commit, list(reversed(commits)),
                old_head)

    def _pick_commit(self, repo, cmt):
        """Apply a commit to the source tree, without committing it

        _prepare_process() must be called before starting to pick commits

        This function must be called before _finish_commit()

        Note that this uses a cherry-pick method, creating a new tree_id each
        time, so can make source-code changes

        Args:
            repo (pygit2.repo): Repo to use
            cmt (Commit): Commit to apply

        Return: tuple:
            tree_id (pygit2.oid): Oid of index with source-changes applied
            commit (pygit2.oid): Old commit being cherry-picked

        Raises:
            ValueError: The cherry-pick resulted in conflicts
        """
        tout.detail(f"- adding {oid(cmt.hash)} {cmt}")
        repo.cherrypick(cmt.hash)
        if repo.index.conflicts:
            raise ValueError('Conflicts detected')

        tree_id = repo.index.write_tree()
        cherry = repo.get(cmt.hash)
        tout.detail(f"cherry {oid(cherry.oid)}")
        return tree_id, cherry

    def _finish_commit(self, repo, tree_id, commit, cur, msg=None):
        """Complete a commit

        This must be called after _pick_commit().

        Args:
            repo (pygit2.repo): Repo to use
            tree_id (pygit2.oid): Oid of index with source-changes applied; if
                None then the existing commit.tree_id is used
            commit (pygit2.oid): Old commit being cherry-picked
            cur (pygit2.reference): Reference to parent to use for the commit
            msg (str): Commit subject and message; None to use commit.message

        Return:
            pygit2.reference: HEAD after the new commit is created
        """
        if msg is None:
            msg = commit.message
        if not tree_id:
            tree_id = commit.tree_id
        repo.create_commit('HEAD', commit.author, commit.committer,
                           msg, tree_id, [cur.target])
        return repo.head

    def _finish_process(self, repo, branch, name, cur, old_head, new_name=None,
                        switch=False, dry_run=False, quiet=False):
        """Finish processing commits

        Args:
            repo (pygit2.repo): Repo to use
            branch (pygit2.branch): Branch returned by _prepare_process()
            name (str): Name of the branch to process
            cur (pygit2.reference): Reference to the last commit created
            old_head (pygit2.reference or None): Original head, as returned by
                _prepare_process(); if not None and switch is False, HEAD is
                set back to this
            new_name (str or None): New name, if a new branch is being created
            switch (bool): True to switch to the new branch after processing;
                otherwise HEAD remains at the original branch, as amended
            dry_run (bool): True to do a dry run, restoring the original tree
                afterwards
            quiet (bool): True to avoid output (used for testing)

        Return:
            pygit2.reference: Final commit after everything is completed
        """
        repo.state_cleanup()

        # Update the branch
        target = repo.revparse_single('HEAD')
        if not quiet:
            tout.info(f'Updating branch {name} from {oid(branch.target)} to '
                      f'{oid(target.oid)}')
        if dry_run:
            if new_name:
                repo.head.set_target(branch.target)
            else:
                branch_oid = branch.peel(pygit2.enums.ObjectType.COMMIT).oid
                repo.head.set_target(branch_oid)
            repo.head.set_target(branch.target)
            repo.set_head(branch.name)
        else:
            if new_name:
                new_branch = repo.branches.create(new_name, target)
                if branch.upstream:
                    new_branch.upstream = branch.upstream
                branch = new_branch
            else:
                branch.set_target(cur.target)
            repo.set_head(branch.name)
        if old_head:
            if not switch:
                repo.set_head(old_head.name)
        return target

    def make_change_id(self, commit):
        """Make a Change ID for a commit

        This is similar to the gerrit script:
            git var GIT_COMMITTER_IDENT ; echo "$refhash" ; cat "README"; }
                | git hash-object --stdin)

        Args:
            commit (pygit2.commit): Commit to process

        Return:
            Change ID in hex format
        """
        sig = commit.committer
        val = hashlib.sha1()
        to_hash = f'{sig.name} <{sig.email}> {sig.time} {sig.offset}'
        val.update(to_hash.encode('utf-8'))
        val.update(str(commit.tree_id).encode('utf-8'))
        val.update(commit.message.encode('utf-8'))
        return val.hexdigest()

    def _filter_commits(self, name, series, seq_to_drop):
        """Filter commits to drop one

        This function rebases the current branch, dropping a single commit,
        thus changing the resulting code in the tree.

        Args:
            name (str): Name of the branch to process
            series (Series): Series object
            seq_to_drop (int): Commit sequence to drop; commits are numbered
                from 0, which is the one after the upstream branch, to
                count - 1
        """
        count = len(series.commits)
        (repo, cur, branch, name, commit, _, _) = self._prepare_process(
            name, count, quiet=True)
        repo.checkout_tree(commit, strategy=CheckoutStrategy.FORCE |
                           CheckoutStrategy.RECREATE_MISSING)
        repo.set_head(commit.oid)
        for seq, cmt in enumerate(series.commits):
            if seq != seq_to_drop:
                tree_id, cherry = self._pick_commit(repo, cmt)
                cur = self._finish_commit(repo, tree_id, cherry, cur)
        self._finish_process(repo, branch, name, cur, None, quiet=True)

    def process_series(self, name, series, new_name=None, switch=False,
                       dry_run=False):
        """Rewrite a series commit messages, leaving code alone

        This is a generator which uses a 'vals' namespace to pass things to
        the controlling function.

        Each time process_series() yields, it sets up:
            commit (Commit): The pygit2 commit that is being processed
            msg (str): Commit message, which can be modified
            info (str): Initially empty; the controlling function can add a
                short message here which will be shown to the user
            final (bool): True if this is the last commit to apply
            seq (int): Current sequence number in the commits to apply (0,,n-1)

        It also sets git HEAD at the commit before this commit being
        processed

        The function can change msg and info, e.g. to add or remove tags from
        the commit.

        Once all commits are processed, vals.oid is set to the oid of the new
        branch, so the controlling function can read it from the last vals
        object it received.

        Args:
            name (str): Name of the branch to process
            series (Series): Series object
            new_name (str or None): New name, if a new branch is to be created
            switch (bool): True to switch to the new branch after processing;
                otherwise HEAD remains at the original branch, as amended
            dry_run (bool): True to do a dry run, restoring the original tree
                afterwards

        Yields:
            SimpleNamespace: 'vals' object as described above
        """
        count = len(series.commits)
        repo, cur, branch, name, _, commits, old_head = self._prepare_process(
            name, count, new_name)
        vals = SimpleNamespace()
        vals.final = False
        tout.info(f"Processing {count} commits from branch '{name}'")

        # Record the message lines
        lines = []
        for seq, cmt in enumerate(series.commits):
            commit = commits[seq]
            vals.commit = commit
            vals.msg = commit.message
            vals.info = ''
            vals.final = seq == len(series.commits) - 1
            vals.seq = seq
            yield vals

            cur = self._finish_commit(repo, None, commit, cur, vals.msg)
            lines.append([vals.info.strip(),
                          f'{oid(cmt.hash)} as {oid(cur.target)} {cmt}'])

        # Use a default so that a series with no commits does not fail here
        max_len = max((len(info) for info, _ in lines), default=0) + 1
        for info, rest in lines:
            if info:
                info += ':'
            tout.info(f'- {info.ljust(max_len)} {rest}')
        target = self._finish_process(repo, branch, name, cur, old_head,
                                      new_name, switch, dry_run)
        vals.oid = target.oid

    def _mark_series(self, name, series, dry_run=False):
        """Mark a series with Change-Id tags

        Args:
            name (str): Name of the series to mark
            series (Series): Series object
            dry_run (bool): True to do a dry run, restoring the original tree
                afterwards

        Return:
            pygit.oid: oid of the new branch

        Raises:
            ValueError: The series has no commits
        """
        vals = None
        for vals in self.process_series(name, series, dry_run=dry_run):
            if CHANGE_ID_TAG not in vals.msg:
                change_id = self.make_change_id(vals.commit)
                vals.msg = vals.msg + f'\n{CHANGE_ID_TAG}: {change_id}'
                tout.detail(" - adding mark")
                vals.info = 'marked'
            else:
                vals.info = 'has mark'

        if vals is None:
            # process_series() never yielded, so there is no oid to report
            raise ValueError(f"No commits to mark in '{name}'")
        return vals.oid

    def update_series(self, branch_name, series, max_vers, new_name=None,
                      dry_run=False, add_vers=None, add_link=None,
                      add_rtags=None, switch=False):
        """Rewrite a series to update the Series-version/Series-links lines

        This updates the series in git; it does not update the database

        Nothing is returned.

        Args:
            branch_name (str): Name of the branch to process
            series (Series): Series object
            max_vers (int): Version number of the series being updated
            new_name (str or None): New name, if a new branch is to be created
            dry_run (bool): True to do a dry run, restoring the original tree
                afterwards
            add_vers (int or None): Version number to add to the series, if
                any; version 1 causes any Series-version line to be removed
            add_link (str or None): Link to add to the series, if any
            add_rtags (list of dict): List of review tags to add, one item for
                each commit, each a dict:
                    key: Response tag (e.g. 'Reviewed-by')
                    value: Set of people who gave that response, each a
                        name/email string
            switch (bool): True to switch to the new branch after processing;
                otherwise HEAD remains at the original branch, as amended
        """
        def _do_version():
            # Version 1 is implied, so it gets no Series-version line
            if add_vers:
                if add_vers == 1:
                    vals.info += f'rm v{add_vers} '
                else:
                    vals.info += f'add v{add_vers} '
                    out.append(f'Series-version: {add_vers}')

        def _do_links(new_links):
            if add_link:
                if 'add' not in vals.info:
                    vals.info += 'add '
                vals.info += f"links '{new_links}' "
            else:
                vals.info += f"upd links '{new_links}' "
            out.append(f'Series-links: {new_links}')

        added_version = False
        added_link = False
        for vals in self.process_series(branch_name, series, new_name, switch,
                                        dry_run):
            out = []
            for line in vals.msg.splitlines():
                m_ver = re.match('Series-version:(.*)', line)
                m_links = re.match('Series-links:(.*)', line)
                if m_ver and add_vers:
                    if ('version' in series and
                            int(series.version) != max_vers):
                        tout.warning(
                            f'Branch {branch_name}: Series-version tag '
                            f'{series.version} does not match expected '
                            f'version {max_vers}')
                    _do_version()
                    added_version = True
                elif m_links:
                    links = series.get_links(m_links.group(1), max_vers)
                    if add_link:
                        links[max_vers] = add_link
                    _do_links(series.build_links(links))
                    added_link = True
                else:
                    out.append(line)
            if vals.final:
                # Add any tags which were not found in any of the commits
                if not added_version and add_vers and add_vers > 1:
                    _do_version()
                if not added_link and add_link:
                    _do_links(f'{max_vers}:{add_link}')

            vals.msg = '\n'.join(out) + '\n'
            if add_rtags and add_rtags[vals.seq]:
                lines = []
                for tag, people in add_rtags[vals.seq].items():
                    for who in people:
                        lines.append(f'{tag}: {who}')
                vals.msg = patchstream.insert_tags(vals.msg.rstrip(),
                                                   sorted(lines))
                vals.info += (f'added {len(lines)} '
                              f"tag{'' if len(lines) == 1 else 's'}")

    def _build_col(self, state, prefix='', base_str=None):
        """Build a patch-state string with colour

        Args:
            state (str): State to colourise (also indicates the colour to use)
            prefix (str): Prefix string to also colourise
            base_str (str or None): String to show instead of state, or None to
                show state

        Return:
            tuple:
                str: String with ANSI colour characters
                str: Spaces needed to pad the (uncoloured) state out to 10
                    characters; the prefix is not counted
        """
        bright = True
        if state == 'accepted':
            col = self.col.GREEN
        elif state == 'awaiting-upstream':
            bright = False
            col = self.col.GREEN
        elif state in ['changes-requested']:
            col = self.col.CYAN
        elif state in ['rejected', 'deferred', 'not-applicable', 'superseded',
                       'handled-elsewhere']:
            col = self.col.RED
        elif not state:
            state = 'unknown'
            col = self.col.MAGENTA
        else:
            # under-review, rfc, needs-review-ack
            col = self.col.WHITE
        out = base_str or SHORTEN_STATE.get(state, state)
        pad = ' ' * (10 - len(out))
        col_state = self.col.build(col, prefix + out, bright)
        return col_state, pad

    def _get_patches(self, series, version):
        """Get a Series object containing the patches in a series

        Args:
            series (str): Name of series to use, or None to use current branch
            version (int): Version number, or None to detect from name

        Return: tuple:
            str: Name of branch, e.g. 'mary2'
            Series: Series object containing the commits and idnum, desc, name
            int: Version number of series, e.g. 2
            OrderedDict:
                key (int): seq
                value (PCOMMIT): record data
            str: series name (for this version)
            str: patchwork link
            str: cover_id
            int: cover_num_comments

        Raises:
            ValueError: The series is not in the database, or does not have
                the requested version
        """
        ser, version = self._parse_series_and_version(series, version)
        if not ser.idnum:
            raise ValueError(f"Unknown series '{series}'")
        self._ensure_version(ser, version)
        svinfo = self.get_ser_ver(ser.idnum, version)
        pwc = self.get_pcommit_dict(svinfo.idnum)

        count = len(pwc)
        branch = self._join_name_version(ser.name, version)
        series = patchstream.get_metadata(branch, 0, count,
                                          git_dir=self.gitdir)
        self._copy_db_fields_to(series, ser)

        return (branch, series, version, pwc, svinfo.name, svinfo.link,
                svinfo.cover_id, svinfo.cover_num_comments)

    def _list_patches(self, branch, pwc, series, desc, cover_id, num_comments,
                      show_commit, show_patch, list_patches, state_totals):
        """List patches along with optional status info

        Args:
            branch (str): Branch name if self.show_progress
            pwc (dict): pcommit records:
                key (int): seq
                value (PCOMMIT): Record from database
            series (Series): Series to show, or None to just use the database
            desc (str): Series title
            cover_id (int): Cover-letter ID
            num_comments (int): The number of comments on the cover letter
            show_commit (bool): True to show the commit and diffstate
            show_patch (bool): True to show the patch
            list_patches (bool): True to list all patches for each series,
                False to just show the series summary on a single line
            state_totals (dict): Holds totals for each state across all patches
                key (str): state name
                value (int): Number of patches in that state

        Return:
            bool: True if OK, False if any commit subjects don't match their
                patchwork subjects
        """
        lines = []
        states = defaultdict(int)
        count = len(pwc)
ok = True 1184 for seq, item in enumerate(pwc.values()): 1185 if series: 1186 cmt = series.commits[seq] 1187 if cmt.subject != item.subject: 1188 ok = False 1189 1190 col_state, pad = self._build_col(item.state) 1191 patch_id = item.patch_id if item.patch_id else '' 1192 if item.num_comments: 1193 comments = str(item.num_comments) 1194 elif item.num_comments is None: 1195 comments = '-' 1196 else: 1197 comments = '' 1198 1199 if show_commit or show_patch: 1200 subject = self.col.build(self.col.BLACK, item.subject, 1201 bright=False, back=self.col.YELLOW) 1202 else: 1203 subject = item.subject 1204 1205 line = (f'{seq:3} {col_state}{pad} {comments.rjust(3)} ' 1206 f'{patch_id:7} {oid(cmt.hash)} {subject}') 1207 lines.append(line) 1208 states[item.state] += 1 1209 out = '' 1210 for state, freq in states.items(): 1211 out += ' ' + self._build_col(state, f'{freq}:')[0] 1212 state_totals[state] += freq 1213 name = '' 1214 if not list_patches: 1215 name = desc or series.desc 1216 name = self.col.build(self.col.YELLOW, name[:41].ljust(41)) 1217 if not ok: 1218 out = '*' + out[1:] 1219 print(f"{branch:16} {name} {len(pwc):5} {out}") 1220 return ok 1221 print(f"Branch '{branch}' (total {len(pwc)}):{out}{name}") 1222 1223 print(self.col.build( 1224 self.col.MAGENTA, 1225 f"Seq State Com PatchId {'Commit'.ljust(HASH_LEN)} Subject")) 1226 1227 comments = '' if num_comments is None else str(num_comments) 1228 if desc or comments or cover_id: 1229 cov = 'Cov' if cover_id else '' 1230 print(self.col.build( 1231 self.col.WHITE, 1232 f"{cov:14} {comments.rjust(3)} {cover_id or '':7} " 1233 f'{desc or series.desc}', 1234 bright=False)) 1235 for seq in range(count): 1236 line = lines[seq] 1237 print(line) 1238 if show_commit or show_patch: 1239 print() 1240 cmt = series.commits[seq] if series else '' 1241 msg = gitutil.show_commit( 1242 cmt.hash, show_commit, True, show_patch, 1243 colour=self.col.enabled(), git_dir=self.gitdir) 1244 sys.stdout.write(msg) 1245 if seq != count - 1: 
1246 print() 1247 print() 1248 1249 return ok 1250 1251 def _find_matched_commit(self, commits, pcm): 1252 """Find a commit in a list of possible matches 1253 1254 Args: 1255 commits (dict of Commit): Possible matches 1256 key (int): sequence number of patch (from 0) 1257 value (Commit): Commit object 1258 pcm (PCOMMIT): Patch to check 1259 1260 Return: 1261 int: Sequence number of matching commit, or None if not found 1262 """ 1263 for seq, cmt in commits.items(): 1264 tout.debug(f"- match subject: '{cmt.subject}'") 1265 if pcm.subject == cmt.subject: 1266 return seq 1267 return None 1268 1269 def _find_matched_patch(self, patches, cmt): 1270 """Find a patch in a list of possible matches 1271 1272 Args: 1273 patches: dict of ossible matches 1274 key (int): sequence number of patch 1275 value (PCOMMIT): patch 1276 cmt (Commit): Commit to check 1277 1278 Return: 1279 int: Sequence number of matching patch, or None if not found 1280 """ 1281 for seq, pcm in patches.items(): 1282 tout.debug(f"- match subject: '{pcm.subject}'") 1283 if cmt.subject == pcm.subject: 1284 return seq 1285 return None 1286 1287 def _sync_one(self, svid, series_name, version, show_comments, 1288 show_cover_comments, gather_tags, cover, patches, dry_run): 1289 """Sync one series to the database 1290 1291 Args: 1292 svid (int): Ser/ver ID 1293 cover (dict or None): Cover letter from patchwork, with keys: 1294 id (int): Cover-letter ID in patchwork 1295 num_comments (int): Number of comments 1296 name (str): Cover-letter name 1297 patches (list of Patch): Patches in the series 1298 """ 1299 pwc = self.get_pcommit_dict(svid) 1300 if gather_tags: 1301 count = len(pwc) 1302 branch = self._join_name_version(series_name, version) 1303 series = patchstream.get_metadata(branch, 0, count, 1304 git_dir=self.gitdir) 1305 1306 _, new_rtag_list = status.do_show_status( 1307 series, cover, patches, show_comments, show_cover_comments, 1308 self.col, warnings_on_stderr=False) 1309 self.update_series(branch, 
series, version, None, dry_run, 1310 add_rtags=new_rtag_list) 1311 1312 updated = 0 1313 for seq, item in enumerate(pwc.values()): 1314 if seq >= len(patches): 1315 continue 1316 patch = patches[seq] 1317 if patch.id: 1318 if self.db.pcommit_update( 1319 Pcommit(item.idnum, seq, None, None, None, patch.state, 1320 patch.id, len(patch.comments))): 1321 updated += 1 1322 if cover: 1323 info = SerVer(svid, None, None, None, cover.id, 1324 cover.num_comments, cover.name, None) 1325 else: 1326 info = SerVer(svid, None, None, None, None, None, patches[0].name, 1327 None) 1328 self.db.ser_ver_set_info(info) 1329 1330 return updated, 1 if cover else 0 1331 1332 async def _gather(self, pwork, link, show_cover_comments): 1333 """Sync the series status from patchwork 1334 1335 Creates a new client sesion and calls _sync() 1336 1337 Args: 1338 pwork (Patchwork): Patchwork object to use 1339 link (str): Patchwork link for the series 1340 show_cover_comments (bool): True to show the comments on the cover 1341 letter 1342 1343 Return: tuple: 1344 COVER object, or None if none or not read_cover_comments 1345 list of PATCH objects 1346 """ 1347 async with aiohttp.ClientSession() as client: 1348 return await pwork.series_get_state(client, link, True, 1349 show_cover_comments) 1350 1351 def _get_fetch_dict(self, sync_all_versions): 1352 """Get a dict of ser_vers to fetch, along with their patchwork links 1353 1354 Args: 1355 sync_all_versions (bool): True to sync all versions of a series, 1356 False to sync only the latest version 1357 1358 Return: tuple: 1359 dict: things to fetch 1360 key (int): svid 1361 value (str): patchwork link for the series 1362 int: number of series which are missing a link 1363 """ 1364 missing = 0 1365 svdict = self.get_ser_ver_dict() 1366 sdict = self.db.series_get_dict_by_id() 1367 to_fetch = {} 1368 1369 if sync_all_versions: 1370 for svinfo in self.get_ser_ver_list(): 1371 ser_ver = svdict[svinfo.idnum] 1372 if svinfo.link: 1373 to_fetch[svinfo.idnum] 
= patchwork.STATE_REQ( 1374 svinfo.link, svinfo.series_id, 1375 sdict[svinfo.series_id].name, svinfo.version, False, 1376 False) 1377 else: 1378 missing += 1 1379 else: 1380 # Find the maximum version for each series 1381 max_vers = self._series_all_max_versions() 1382 1383 # Get a list of links to fetch 1384 for svid, series_id, version in max_vers: 1385 ser_ver = svdict[svid] 1386 if series_id not in sdict: 1387 # skip archived item 1388 continue 1389 if ser_ver.link: 1390 to_fetch[svid] = patchwork.STATE_REQ( 1391 ser_ver.link, series_id, sdict[series_id].name, 1392 version, False, False) 1393 else: 1394 missing += 1 1395 1396 # order by series name, version 1397 ordered = OrderedDict() 1398 for svid in sorted( 1399 to_fetch, 1400 key=lambda k: (to_fetch[k].series_name, to_fetch[k].version)): 1401 sync = to_fetch[svid] 1402 ordered[svid] = sync 1403 1404 return ordered, missing 1405 1406 async def _sync_all(self, client, pwork, to_fetch): 1407 """Sync all series status from patchwork 1408 1409 Args: 1410 pwork (Patchwork): Patchwork object to use 1411 sync_all_versions (bool): True to sync all versions of a series, 1412 False to sync only the latest version 1413 gather_tags (bool): True to gather review/test tags 1414 1415 Return: list of tuple: 1416 COVER object, or None if none or not read_cover_comments 1417 list of PATCH objects 1418 """ 1419 with pwork.collect_stats() as stats: 1420 tasks = [pwork.series_get_state(client, sync.link, True, True) 1421 for sync in to_fetch.values() if sync.link] 1422 result = await asyncio.gather(*tasks) 1423 return result, stats.request_count 1424 1425 async def _do_series_sync_all(self, pwork, to_fetch): 1426 async with aiohttp.ClientSession() as client: 1427 return await self._sync_all(client, pwork, to_fetch) 1428 1429 def _progress_one(self, ser, show_all_versions, list_patches, 1430 state_totals): 1431 """Show progress information for all versions in a series 1432 1433 Args: 1434 ser (Series): Series to use 1435 
show_all_versions (bool): True to show all versions of a series, 1436 False to show only the final version 1437 list_patches (bool): True to list all patches for each series, 1438 False to just show the series summary on a single line 1439 state_totals (dict): Holds totals for each state across all patches 1440 key (str): state name 1441 value (int): Number of patches in that state 1442 1443 Return: tuple 1444 int: Number of series shown 1445 int: Number of patches shown 1446 int: Number of version which need a 'scan' 1447 """ 1448 max_vers = self._series_max_version(ser.idnum) 1449 name, desc = self._get_series_info(ser.idnum) 1450 coloured = self.col.build(self.col.BLACK, desc, bright=False, 1451 back=self.col.YELLOW) 1452 versions = self._get_version_list(ser.idnum) 1453 vstr = list(map(str, versions)) 1454 1455 if list_patches: 1456 print(f"{name}: {coloured} (versions: {' '.join(vstr)})") 1457 add_blank_line = False 1458 total_series = 0 1459 total_patches = 0 1460 need_scan = 0 1461 for ver in versions: 1462 if not show_all_versions and ver != max_vers: 1463 continue 1464 if add_blank_line: 1465 print() 1466 _, pwc = self._series_get_version_stats(ser.idnum, ver) 1467 count = len(pwc) 1468 branch = self._join_name_version(ser.name, ver) 1469 series = patchstream.get_metadata(branch, 0, count, 1470 git_dir=self.gitdir) 1471 svinfo = self.get_ser_ver(ser.idnum, ver) 1472 self._copy_db_fields_to(series, ser) 1473 1474 ok = self._list_patches( 1475 branch, pwc, series, svinfo.name, svinfo.cover_id, 1476 svinfo.cover_num_comments, False, False, list_patches, 1477 state_totals) 1478 if not ok: 1479 need_scan += 1 1480 add_blank_line = list_patches 1481 total_series += 1 1482 total_patches += count 1483 return total_series, total_patches, need_scan 1484 1485 def _summary_one(self, ser): 1486 """Show summary information for the latest version in a series 1487 1488 Args: 1489 series (str): Name of series to use, or None to show progress for 1490 all series 1491 """ 
1492 max_vers = self._series_max_version(ser.idnum) 1493 name, desc = self._get_series_info(ser.idnum) 1494 stats, pwc = self._series_get_version_stats(ser.idnum, max_vers) 1495 states = {x.state for x in pwc.values()} 1496 state = 'accepted' 1497 for val in ['awaiting-upstream', 'changes-requested', 'rejected', 1498 'deferred', 'not-applicable', 'superseded', 1499 'handled-elsewhere']: 1500 if val in states: 1501 state = val 1502 state_str, pad = self._build_col(state, base_str=name) 1503 print(f"{state_str}{pad} {stats.rjust(6)} {desc}") 1504 1505 def _series_max_version(self, idnum): 1506 """Find the latest version of a series 1507 1508 Args: 1509 idnum (int): Series ID to look up 1510 1511 Return: 1512 int: maximum version 1513 """ 1514 return self.db.series_get_max_version(idnum) 1515 1516 def _series_all_max_versions(self): 1517 """Find the latest version of all series 1518 1519 Return: list of: 1520 int: ser_ver ID 1521 int: series ID 1522 int: Maximum version 1523 """ 1524 return self.db.series_get_all_max_versions() 1525