1# SPDX-License-Identifier: GPL-2.0+
2#
3# Copyright 2025 Simon Glass <sjg@chromium.org>
4#
5"""Helper functions for handling the 'series' subcommand
6"""
7
8import asyncio
9from collections import OrderedDict, defaultdict, namedtuple
10from datetime import datetime
11import hashlib
12import os
13import re
14import sys
15import time
16from types import SimpleNamespace
17
18import aiohttp
19import pygit2
20from pygit2.enums import CheckoutStrategy
21
22from u_boot_pylib import gitutil
23from u_boot_pylib import terminal
24from u_boot_pylib import tout
25
26from patman import patchstream
27from patman.database import Database, Pcommit, SerVer
28from patman import patchwork
29from patman.series import Series
30from patman import status
31
32
# Tag to use for Change IDs
CHANGE_ID_TAG = 'Change-Id'

# Number of leading characters of a hash to display; see oid()
HASH_LEN = 10

# Shorter version of some states, to save horizontal space:
#   key: full state name
#   value: abbreviated name to show instead
SHORTEN_STATE = {
    'handled-elsewhere': 'elsewhere',
    'awaiting-upstream': 'awaiting',
    'not-applicable': 'n/a',
    'changes-requested': 'changes',
}

# Summary info returned from Cseries.link_auto_all(); the fields are name,
# version, link, desc and result
AUTOLINK = namedtuple('autolink', 'name,version,link,desc,result')
49
50
def oid(oid_val):
    """Shorten a hash so that it can be shown to the user

    Git chooses how many hex digits to show based on the size of the repo.
    For lists shown to the user a fixed length (HASH_LEN) is used, for now

    Args:
        oid_val (str or pygit2.Oid): Hash value to shorten

    Return:
        str: First HASH_LEN characters of the string form of the hash
    """
    full_hash = str(oid_val)
    return full_hash[:HASH_LEN]
65
66
def split_name_version(in_name):
    """Split a branch name into its series name and its version

    The name is everything before the first digit. The version is the run of
    digits which immediately follows the name; anything after that run is
    ignored.

    For example:
        'series' returns ('series', None)
        'series3' returns ('series', 3)
        '3' returns ('', 3)

    Args:
        in_name (str): Name to parse

    Return:
        tuple:
            str: series name; empty if in_name starts with a digit
            int: series version, or None if there is none in in_name
    """
    # Both groups may be empty, so this pattern matches any string and there
    # is no need to handle a failed match
    m_ver = re.match(r'([^0-9]*)(\d*)', in_name)
    name, digits = m_ver.groups()
    version = int(digits) if digits else None
    return name, version
90
91
92class CseriesHelper:
93    """Helper functions for Cseries
94
95    This class handles database read/write as well as operations in a git
96    directory to update series information.
97    """
98    def __init__(self, topdir=None, colour=terminal.COLOR_IF_TERMINAL):
99        """Set up a new CseriesHelper
100
101        Args:
102            topdir (str): Top-level directory of the repo
103            colour (terminal.enum): Whether to enable ANSI colour or not
104
105        Properties:
106            gitdir (str): Git directory (typically topdir + '/.git')
107            db (Database): Database handler
108            col (terminal.Colour): Colour object
109            _fake_time (float): Holds the current fake time for tests, in
110                seconds
111            _fake_sleep (func): Function provided by a test; called to fake a
112                'time.sleep()' call and take whatever action it wants to take.
113                The only argument is the (Float) time to sleep for; it returns
114                nothing
115            loop (asyncio event loop): Loop used for Patchwork operations
116        """
117        self.topdir = topdir
118        self.gitdir = None
119        self.db = None
120        self.col = terminal.Color(colour)
121        self._fake_time = None
122        self._fake_sleep = None
123        self.fake_now = None
124        self.loop = asyncio.get_event_loop()
125
126    def open_database(self):
127        """Open the database ready for use"""
128        if not self.topdir:
129            self.topdir = gitutil.get_top_level()
130            if not self.topdir:
131                raise ValueError('No git repo detected in current directory')
132        self.gitdir = os.path.join(self.topdir, '.git')
133        fname = f'{self.topdir}/.patman.db'
134
135        # For the first instance, start it up with the expected schema
136        self.db, is_new = Database.get_instance(fname)
137        if is_new:
138            self.db.start()
139        else:
140            # If a previous test has already checked the schema, just open it
141            self.db.open_it()
142
143    def close_database(self):
144        """Close the database"""
145        if self.db:
146            self.db.close()
147
148    def commit(self):
149        """Commit changes to the database"""
150        self.db.commit()
151
152    def rollback(self):
153        """Roll back changes to the database"""
154        self.db.rollback()
155
156    def set_fake_time(self, fake_sleep):
157        """Setup the fake timer
158
159        Args:
160            fake_sleep (func(float)): Function to call to fake a sleep
161        """
162        self._fake_time = 0
163        self._fake_sleep = fake_sleep
164
165    def inc_fake_time(self, inc_s):
166        """Increment the fake time
167
168        Args:
169            inc_s (float): Amount to increment the fake time by
170        """
171        self._fake_time += inc_s
172
173    def get_time(self):
174        """Get the current time, fake or real
175
176        This function should always be used to read the time so that faking the
177        time works correctly in tests.
178
179        Return:
180            float: Fake time, if time is being faked, else real time
181        """
182        if self._fake_time is not None:
183            return self._fake_time
184        return time.monotonic()
185
186    def sleep(self, time_s):
187        """Sleep for a while
188
189        This function should always be used to sleep so that faking the time
190        works correctly in tests.
191
192        Args:
193            time_s (float): Amount of seconds to sleep for
194        """
195        print(f'Sleeping for {time_s} seconds')
196        if self._fake_time is not None:
197            self._fake_sleep(time_s)
198        else:
199            time.sleep(time_s)
200
201    def get_now(self):
202        """Get the time now
203
204        This function should always be used to read the datetime, so that
205        faking the time works correctly in tests
206
207        Return:
208            DateTime object
209        """
210        if self.fake_now:
211            return self.fake_now
212        return datetime.now()
213
214    def get_ser_ver_list(self):
215        """Get a list of patchwork entries from the database
216
217        Return:
218            list of SER_VER
219        """
220        return self.db.ser_ver_get_list()
221
222    def get_ser_ver_dict(self):
223        """Get a dict of patchwork entries from the database
224
225        Return: dict contain all records:
226            key (int): ser_ver id
227            value (SER_VER): Information about one ser_ver record
228        """
229        svlist = self.get_ser_ver_list()
230        svdict = {}
231        for sver in svlist:
232            svdict[sver.idnum] = sver
233        return svdict
234
235    def get_upstream_dict(self):
236        """Get a list of upstream entries from the database
237
238        Return:
239            OrderedDict:
240                key (str): upstream name
241                value (str): url
242        """
243        return self.db.upstream_get_dict()
244
245    def get_pcommit_dict(self, find_svid=None):
246        """Get a dict of pcommits entries from the database
247
248        Args:
249            find_svid (int): If not None, finds the records associated with a
250                particular series and version
251
252        Return:
253            OrderedDict:
254                key (int): record ID if find_svid is None, else seq
255                value (PCOMMIT): record data
256        """
257        pcdict = OrderedDict()
258        for rec in self.db.pcommit_get_list(find_svid):
259            if find_svid is not None:
260                pcdict[rec.seq] = rec
261            else:
262                pcdict[rec.idnum] = rec
263        return pcdict
264
265    def _get_series_info(self, idnum):
266        """Get information for a series from the database
267
268        Args:
269            idnum (int): Series ID to look up
270
271        Return: tuple:
272            str: Series name
273            str: Series description
274
275        Raises:
276            ValueError: Series is not found
277        """
278        return self.db.series_get_info(idnum)
279
280    def prep_series(self, name, end=None):
281        """Prepare to work with a series
282
283        Args:
284            name (str): Branch name with version appended, e.g. 'fix2'
285            end (str or None): Commit to end at, e.g. 'my_branch~16'. Only
286                commits up to that are processed. None to process commits up to
287                the upstream branch
288
289        Return: tuple:
290            str: Series name, e.g. 'fix'
291            Series: Collected series information, including name
292            int: Version number, e.g. 2
293            str: Message to show
294        """
295        ser, version = self._parse_series_and_version(name, None)
296        if not name:
297            name = self._get_branch_name(ser.name, version)
298
299        # First check we have a branch with this name
300        if not gitutil.check_branch(name, git_dir=self.gitdir):
301            raise ValueError(f"No branch named '{name}'")
302
303        count = gitutil.count_commits_to_branch(name, self.gitdir, end)
304        if not count:
305            raise ValueError('Cannot detect branch automatically: '
306                             'Perhaps use -U <upstream-commit> ?')
307
308        series = patchstream.get_metadata(name, 0, count, git_dir=self.gitdir)
309        self._copy_db_fields_to(series, ser)
310        msg = None
311        if end:
312            repo = pygit2.init_repository(self.gitdir)
313            target = repo.revparse_single(end)
314            first_line = target.message.splitlines()[0]
315            msg = f'Ending before {oid(target.id)} {first_line}'
316
317        return name, series, version, msg
318
319    def _copy_db_fields_to(self, series, in_series):
320        """Copy over fields used by Cseries from one series to another
321
322        This copes desc, idnum and name
323
324        Args:
325            series (Series): Series to copy to
326            in_series (Series): Series to copy from
327        """
328        series.desc = in_series.desc
329        series.idnum = in_series.idnum
330        series.name = in_series.name
331
332    def _handle_mark(self, branch_name, in_series, version, mark,
333                     allow_unmarked, force_version, dry_run):
334        """Handle marking a series, checking for unmarked commits, etc.
335
336        Args:
337            branch_name (str): Name of branch to sync, or None for current one
338            in_series (Series): Series object
339            version (int): branch version, e.g. 2 for 'mychange2'
340            mark (bool): True to mark each commit with a change ID
341            allow_unmarked (str): True to not require each commit to be marked
342            force_version (bool): True if ignore a Series-version tag that
343                doesn't match its branch name
344            dry_run (bool): True to do a dry run
345
346        Returns:
347            Series: New series object, if the series was marked;
348                copy_db_fields_to() is used to copy fields over
349
350        Raises:
351            ValueError: Series being unmarked when it should be marked, etc.
352        """
353        series = in_series
354        if 'version' in series and int(series.version) != version:
355            msg = (f"Series name '{branch_name}' suggests version {version} "
356                   f"but Series-version tag indicates {series.version}")
357            if not force_version:
358                raise ValueError(msg + ' (see --force-version)')
359
360            tout.warning(msg)
361            tout.warning(f'Updating Series-version tag to version {version}')
362            self.update_series(branch_name, series, int(series.version),
363                                new_name=None, dry_run=dry_run,
364                                add_vers=version)
365
366            # Collect the commits again, as the hashes have changed
367            series = patchstream.get_metadata(branch_name, 0,
368                                              len(series.commits),
369                                              git_dir=self.gitdir)
370            self._copy_db_fields_to(series, in_series)
371
372        if mark:
373            add_oid = self._mark_series(branch_name, series, dry_run=dry_run)
374
375            # Collect the commits again, as the hashes have changed
376            series = patchstream.get_metadata(add_oid, 0, len(series.commits),
377                                              git_dir=self.gitdir)
378            self._copy_db_fields_to(series, in_series)
379
380        bad_count = 0
381        for commit in series.commits:
382            if not commit.change_id:
383                bad_count += 1
384        if bad_count and not allow_unmarked:
385            raise ValueError(
386                f'{bad_count} commit(s) are unmarked; please use -m or -M')
387
388        return series
389
390    def _add_series_commits(self, series, svid):
391        """Add a commits from a series into the database
392
393        Args:
394            series (Series): Series containing commits to add
395            svid (int): ser_ver-table ID to use for each commit
396        """
397        to_add = [Pcommit(None, seq, commit.subject, None, commit.change_id,
398                          None, None, None)
399                  for seq, commit in enumerate(series.commits)]
400
401        self.db.pcommit_add_list(svid, to_add)
402
403    def get_series_by_name(self, name, include_archived=False):
404        """Get a Series object from the database by name
405
406        Args:
407            name (str): Name of series to get
408            include_archived (bool): True to search in archives series
409
410        Return:
411            Series: Object containing series info, or None if none
412        """
413        idnum = self.db.series_find_by_name(name, include_archived)
414        if not idnum:
415            return None
416        name, desc = self.db.series_get_info(idnum)
417
418        return Series.from_fields(idnum, name, desc)
419
420    def _get_branch_name(self, name, version):
421        """Get the branch name for a particular version
422
423        Args:
424            name (str): Base name of branch
425            version (int): Version number to use
426        """
427        return name + (f'{version}' if version > 1 else '')
428
429    def _ensure_version(self, ser, version):
430        """Ensure that a version exists in a series
431
432        Args:
433            ser (Series): Series information, with idnum and name used here
434            version (int): Version to check
435
436        Returns:
437            list of int: List of versions
438        """
439        versions = self._get_version_list(ser.idnum)
440        if version not in versions:
441            raise ValueError(
442                f"Series '{ser.name}' does not have a version {version}")
443        return versions
444
445    def _set_link(self, ser_id, name, version, link, update_commit,
446                  dry_run=False):
447        """Add / update a series-links link for a series
448
449        Args:
450            ser_id (int): Series ID number
451            name (str): Series name (used to find the branch)
452            version (int): Version number (used to update the database)
453            link (str): Patchwork link-string for the series
454            update_commit (bool): True to update the current commit with the
455                link
456            dry_run (bool): True to do a dry run
457
458        Return:
459            bool: True if the database was update, False if the ser_id or
460                version was not found
461        """
462        if update_commit:
463            branch_name = self._get_branch_name(name, version)
464            _, ser, max_vers, _ = self.prep_series(branch_name)
465            self.update_series(branch_name, ser, max_vers, add_vers=version,
466                                dry_run=dry_run, add_link=link)
467        if link is None:
468            link = ''
469        updated = 1 if self.db.ser_ver_set_link(ser_id, version, link) else 0
470        if dry_run:
471            self.rollback()
472        else:
473            self.commit()
474
475        return updated
476
477    def _get_autolink_dict(self, sdict, link_all_versions):
478        """Get a dict of ser_vers to fetch, along with their patchwork links
479
480        Note that this returns items that already have links, as well as those
481        without links
482
483        Args:
484            sdict:
485                key: series ID
486                value: Series with idnum, name and desc filled out
487            link_all_versions (bool): True to sync all versions of a series,
488                False to sync only the latest version
489
490        Return: tuple:
491            dict:
492                key (int): svid
493                value (tuple):
494                   int: series ID
495                   str: series name
496                   int: series version
497                   str: patchwork link for the series, or None if none
498                   desc: cover-letter name / series description
499        """
500        svdict = self.get_ser_ver_dict()
501        to_fetch = {}
502
503        if link_all_versions:
504            for svinfo in self.get_ser_ver_list():
505                ser = sdict[svinfo.series_id]
506
507                pwc = self.get_pcommit_dict(svinfo.idnum)
508                count = len(pwc)
509                branch = self._join_name_version(ser.name, svinfo.version)
510                series = patchstream.get_metadata(branch, 0, count,
511                                                  git_dir=self.gitdir)
512                self._copy_db_fields_to(series, ser)
513
514                to_fetch[svinfo.idnum] = (svinfo.series_id, series.name,
515                                          svinfo.version, svinfo.link, series)
516        else:
517            # Find the maximum version for each series
518            max_vers = self._series_all_max_versions()
519
520            # Get a list of links to fetch
521            for svid, ser_id, version in max_vers:
522                svinfo = svdict[svid]
523                ser = sdict[ser_id]
524
525                pwc = self.get_pcommit_dict(svid)
526                count = len(pwc)
527                branch = self._join_name_version(ser.name, version)
528                series = patchstream.get_metadata(branch, 0, count,
529                                                  git_dir=self.gitdir)
530                self._copy_db_fields_to(series, ser)
531
532                to_fetch[svid] = (ser_id, series.name, version, svinfo.link,
533                                  series)
534        return to_fetch
535
536    def _get_version_list(self, idnum):
537        """Get a list of the versions available for a series
538
539        Args:
540            idnum (int): ID of series to look up
541
542        Return:
543            str: List of versions
544        """
545        if idnum is None:
546            raise ValueError('Unknown series idnum')
547        return self.db.series_get_version_list(idnum)
548
549    def _join_name_version(self, in_name, version):
550        """Convert a series name plus a version into a branch name
551
552        For example:
553            ('series', 1) returns 'series'
554            ('series', 3) returns 'series3'
555
556        Args:
557            in_name (str): Series name
558            version (int): Version number
559
560        Return:
561            str: associated branch name
562        """
563        if version == 1:
564            return in_name
565        return f'{in_name}{version}'
566
567    def _parse_series(self, name, include_archived=False):
568        """Parse the name of a series, or detect it from the current branch
569
570        Args:
571            name (str or None): name of series
572            include_archived (bool): True to search in archives series
573
574        Return:
575            Series: New object with the name set; idnum is also set if the
576                series exists in the database
577        """
578        if not name:
579            name = gitutil.get_branch(self.gitdir)
580        name, _ = split_name_version(name)
581        ser = self.get_series_by_name(name, include_archived)
582        if not ser:
583            ser = Series()
584            ser.name = name
585        return ser
586
587    def _parse_series_and_version(self, in_name, in_version):
588        """Parse name and version of a series, or detect from current branch
589
590        Figures out the name from in_name, or if that is None, from the current
591            branch.
592
593        Uses the version in_version, or if that is None, uses the int at the
594        end of the name (e.g. 'series' is version 1, 'series4' is version 4)
595
596        Args:
597            in_name (str or None): name of series
598            in_version (str or None): version of series
599
600        Return:
601            tuple:
602                Series: New object with the name set; idnum is also set if the
603                    series exists in the database
604                int: Series version-number detected from the name
605                    (e.g. 'fred' is version 1, 'fred2' is version 2)
606        """
607        name = in_name
608        if not name:
609            name = gitutil.get_branch(self.gitdir)
610            if not name:
611                raise ValueError('No branch detected: please use -s <series>')
612        name, version = split_name_version(name)
613        if not name:
614            raise ValueError(f"Series name '{in_name}' cannot be a number, "
615                             f"use '<name><version>'")
616        if in_version:
617            if version and version != in_version:
618                tout.warning(
619                    f"Version mismatch: -V has {in_version} but branch name "
620                    f'indicates {version}')
621            version = in_version
622        if not version:
623            version = 1
624        if version > 99:
625            raise ValueError(f"Version {version} exceeds 99")
626        ser = self.get_series_by_name(name)
627        if not ser:
628            ser = Series()
629            ser.name = name
630        return ser, version
631
632    def _series_get_version_stats(self, idnum, vers):
633        """Get the stats for a series
634
635        Args:
636            idnum (int): ID number of series to process
637            vers (int): Version number to process
638
639        Return:
640            tuple:
641                str: Status string, '<accepted>/<count>'
642                OrderedDict:
643                    key (int): record ID if find_svid is None, else seq
644                    value (PCOMMIT): record data
645        """
646        svid, link = self._get_series_svid_link(idnum, vers)
647        pwc = self.get_pcommit_dict(svid)
648        count = len(pwc.values())
649        if link:
650            accepted = 0
651            for pcm in pwc.values():
652                accepted += pcm.state == 'accepted'
653        else:
654            accepted = '-'
655        return f'{accepted}/{count}', pwc
656
657    def get_series_svid(self, series_id, version):
658        """Get the patchwork ID of a series version
659
660        Args:
661            series_id (int): id of the series to look up
662            version (int): version number to look up
663
664        Return:
665            str: link found
666
667        Raises:
668            ValueError: No matching series found
669        """
670        return self._get_series_svid_link(series_id, version)[0]
671
672    def _get_series_svid_link(self, series_id, version):
673        """Get the patchwork ID of a series version
674
675        Args:
676            series_id (int): series ID to look up
677            version (int): version number to look up
678
679        Return:
680            tuple:
681                int: record id
682                str: link
683        """
684        recs = self.get_ser_ver(series_id, version)
685        return recs.idnum, recs.link
686
687    def get_ser_ver(self, series_id, version):
688        """Get the patchwork details for a series version
689
690        Args:
691            series_id (int): series ID to look up
692            version (int): version number to look up
693
694        Return:
695            SER_VER: Requested information
696
697        Raises:
698            ValueError: There is no matching idnum/version
699        """
700        return self.db.ser_ver_get_for_series(series_id, version)
701
702    def _prepare_process(self, name, count, new_name=None, quiet=False):
703        """Get ready to process all commits in a branch
704
705        Args:
706            name (str): Name of the branch to process
707            count (int): Number of commits
708            new_name (str or None): New name, if a new branch is to be created
709            quiet (bool): True to avoid output (used for testing)
710
711        Return: tuple:
712            pygit2.repo: Repo to use
713            pygit2.oid: Upstream commit, onto which commits should be added
714            Pygit2.branch: Original branch, for later use
715            str: (Possibly new) name of branch to process
716            list of Commit: commits to process, in order
717            pygit2.Reference: Original head before processing started
718        """
719        upstream_guess = gitutil.get_upstream(self.gitdir, name)[0]
720
721        tout.debug(f"_process_series name '{name}' new_name '{new_name}' "
722                   f"upstream_guess '{upstream_guess}'")
723        dirty = gitutil.check_dirty(self.gitdir, self.topdir)
724        if dirty:
725            raise ValueError(
726                f"Modified files exist: use 'git status' to check: "
727                f'{dirty[:5]}')
728        repo = pygit2.init_repository(self.gitdir)
729
730        commit = None
731        upstream_name = None
732        if upstream_guess:
733            try:
734                upstream = repo.lookup_reference(upstream_guess)
735                upstream_name = upstream.name
736                commit = upstream.peel(pygit2.enums.ObjectType.COMMIT)
737            except KeyError:
738                pass
739            except pygit2.repository.InvalidSpecError as exc:
740                print(f"Error '{exc}'")
741        if not upstream_name:
742            upstream_name = f'{name}~{count}'
743            commit = repo.revparse_single(upstream_name)
744
745        branch = repo.lookup_branch(name)
746        if not quiet:
747            tout.info(
748                f'Checking out upstream commit {upstream_name}: '
749                f'{oid(commit.oid)}')
750
751        old_head = repo.head
752        if old_head.shorthand == name:
753            old_head = None
754        else:
755            old_head = repo.head
756
757        if new_name:
758            name = new_name
759        repo.set_head(commit.oid)
760
761        commits = []
762        cmt = repo.get(branch.target)
763        for _ in range(count):
764            commits.append(cmt)
765            cmt = cmt.parents[0]
766
767        return (repo, repo.head, branch, name, commit, list(reversed(commits)),
768                old_head)
769
770    def _pick_commit(self, repo, cmt):
771        """Apply a commit to the source tree, without committing it
772
773        _prepare_process() must be called before starting to pick commits
774
775        This function must be called before _finish_commit()
776
777        Note that this uses a cherry-pick method, creating a new tree_id each
778        time, so can make source-code changes
779
780        Args:
781            repo (pygit2.repo): Repo to use
782            cmt (Commit): Commit to apply
783
784        Return: tuple:
785            tree_id (pygit2.oid): Oid of index with source-changes applied
786            commit (pygit2.oid): Old commit being cherry-picked
787        """
788        tout.detail(f"- adding {oid(cmt.hash)} {cmt}")
789        repo.cherrypick(cmt.hash)
790        if repo.index.conflicts:
791            raise ValueError('Conflicts detected')
792
793        tree_id = repo.index.write_tree()
794        cherry = repo.get(cmt.hash)
795        tout.detail(f"cherry {oid(cherry.oid)}")
796        return tree_id, cherry
797
798    def _finish_commit(self, repo, tree_id, commit, cur, msg=None):
799        """Complete a commit
800
801        This must be called after _pick_commit().
802
803        Args:
804            repo (pygit2.repo): Repo to use
805            tree_id (pygit2.oid): Oid of index with source-changes applied; if
806                None then the existing commit.tree_id is used
807            commit (pygit2.oid): Old commit being cherry-picked
808            cur (pygit2.reference): Reference to parent to use for the commit
809            msg (str): Commit subject and message; None to use commit.message
810        """
811        if msg is None:
812            msg = commit.message
813        if not tree_id:
814            tree_id = commit.tree_id
815        repo.create_commit('HEAD', commit.author, commit.committer,
816                           msg, tree_id, [cur.target])
817        return repo.head
818
819    def _finish_process(self, repo, branch, name, cur, old_head, new_name=None,
820                        switch=False, dry_run=False, quiet=False):
821        """Finish processing commits
822
823        Args:
824            repo (pygit2.repo): Repo to use
825            branch (pygit2.branch): Branch returned by _prepare_process()
826            name (str): Name of the branch to process
827            new_name (str or None): New name, if a new branch is being created
828            switch (bool): True to switch to the new branch after processing;
829                otherwise HEAD remains at the original branch, as amended
830            dry_run (bool): True to do a dry run, restoring the original tree
831                afterwards
832            quiet (bool): True to avoid output (used for testing)
833
834        Return:
835            pygit2.reference: Final commit after everything is completed
836        """
837        repo.state_cleanup()
838
839        # Update the branch
840        target = repo.revparse_single('HEAD')
841        if not quiet:
842            tout.info(f'Updating branch {name} from {oid(branch.target)} to '
843                      f'{str(target.oid)[:HASH_LEN]}')
844        if dry_run:
845            if new_name:
846                repo.head.set_target(branch.target)
847            else:
848                branch_oid = branch.peel(pygit2.enums.ObjectType.COMMIT).oid
849                repo.head.set_target(branch_oid)
850            repo.head.set_target(branch.target)
851            repo.set_head(branch.name)
852        else:
853            if new_name:
854                new_branch = repo.branches.create(new_name, target)
855                if branch.upstream:
856                    new_branch.upstream = branch.upstream
857                branch = new_branch
858            else:
859                branch.set_target(cur.target)
860            repo.set_head(branch.name)
861        if old_head:
862            if not switch:
863                repo.set_head(old_head.name)
864        return target
865
866    def make_change_id(self, commit):
867        """Make a Change ID for a commit
868
869        This is similar to the gerrit script:
870        git var GIT_COMMITTER_IDENT ; echo "$refhash" ; cat "README"; }
871            | git hash-object --stdin)
872
873        Args:
874            commit (pygit2.commit): Commit to process
875
876        Return:
877            Change ID in hex format
878        """
879        sig = commit.committer
880        val = hashlib.sha1()
881        to_hash = f'{sig.name} <{sig.email}> {sig.time} {sig.offset}'
882        val.update(to_hash.encode('utf-8'))
883        val.update(str(commit.tree_id).encode('utf-8'))
884        val.update(commit.message.encode('utf-8'))
885        return val.hexdigest()
886
887    def _filter_commits(self, name, series, seq_to_drop):
888        """Filter commits to drop one
889
890        This function rebases the current branch, dropping a single commit,
891        thus changing the resulting code in the tree.
892
893        Args:
894            name (str): Name of the branch to process
895            series (Series): Series object
896            seq_to_drop (int): Commit sequence to drop; commits are numbered
897                from 0, which is the one after the upstream branch, to
898                count - 1
899        """
900        count = len(series.commits)
901        (repo, cur, branch, name, commit, _, _) = self._prepare_process(
902            name, count, quiet=True)
903        repo.checkout_tree(commit, strategy=CheckoutStrategy.FORCE |
904                           CheckoutStrategy.RECREATE_MISSING)
905        repo.set_head(commit.oid)
906        for seq, cmt in enumerate(series.commits):
907            if seq != seq_to_drop:
908                tree_id, cherry = self._pick_commit(repo, cmt)
909                cur = self._finish_commit(repo, tree_id, cherry, cur)
910        self._finish_process(repo, branch, name, cur, None, quiet=True)
911
912    def process_series(self, name, series, new_name=None, switch=False,
913                       dry_run=False):
914        """Rewrite a series commit messages, leaving code alone
915
916        This uses a 'vals' namespace to pass things to the controlling
917        function.
918
919        Each time _process_series() yields, it sets up:
920            commit (Commit): The pygit2 commit that is being processed
921            msg (str): Commit message, which can be modified
922            info (str): Initially empty; the controlling function can add a
923                short message here which will be shown to the user
924            final (bool): True if this is the last commit to apply
925            seq (int): Current sequence number in the commits to apply (0,,n-1)
926
927            It also sets git HEAD at the commit before this commit being
928            processed
929
930        The function can change msg and info, e.g. to add or remove tags from
931        the commit.
932
933        Args:
934            name (str): Name of the branch to process
935            series (Series): Series object
936            new_name (str or None): New name, if a new branch is to be created
937            switch (bool): True to switch to the new branch after processing;
938                otherwise HEAD remains at the original branch, as amended
939            dry_run (bool): True to do a dry run, restoring the original tree
940                afterwards
941
942        Return:
943            pygit.oid: oid of the new branch
944        """
945        count = len(series.commits)
946        repo, cur, branch, name, _, commits, old_head = self._prepare_process(
947            name, count, new_name)
948        vals = SimpleNamespace()
949        vals.final = False
950        tout.info(f"Processing {count} commits from branch '{name}'")
951
952        # Record the message lines
953        lines = []
954        for seq, cmt in enumerate(series.commits):
955            commit = commits[seq]
956            vals.commit = commit
957            vals.msg = commit.message
958            vals.info = ''
959            vals.final = seq == len(series.commits) - 1
960            vals.seq = seq
961            yield vals
962
963            cur = self._finish_commit(repo, None, commit, cur, vals.msg)
964            lines.append([vals.info.strip(),
965                          f'{oid(cmt.hash)} as {oid(cur.target)} {cmt}'])
966
967        max_len = max(len(info) for info, rest in lines) + 1
968        for info, rest in lines:
969            if info:
970                info += ':'
971            tout.info(f'- {info.ljust(max_len)} {rest}')
972        target = self._finish_process(repo, branch, name, cur, old_head,
973                                      new_name, switch, dry_run)
974        vals.oid = target.oid
975
976    def _mark_series(self, name, series, dry_run=False):
977        """Mark a series with Change-Id tags
978
979        Args:
980            name (str): Name of the series to mark
981            series (Series): Series object
982            dry_run (bool): True to do a dry run, restoring the original tree
983                afterwards
984
985        Return:
986            pygit.oid: oid of the new branch
987        """
988        vals = None
989        for vals in self.process_series(name, series, dry_run=dry_run):
990            if CHANGE_ID_TAG not in vals.msg:
991                change_id = self.make_change_id(vals.commit)
992                vals.msg = vals.msg + f'\n{CHANGE_ID_TAG}: {change_id}'
993                tout.detail("   - adding mark")
994                vals.info = 'marked'
995            else:
996                vals.info = 'has mark'
997
998        return vals.oid
999
1000    def update_series(self, branch_name, series, max_vers, new_name=None,
1001                       dry_run=False, add_vers=None, add_link=None,
1002                       add_rtags=None, switch=False):
1003        """Rewrite a series to update the Series-version/Series-links lines
1004
1005        This updates the series in git; it does not update the database
1006
1007        Args:
1008            branch_name (str): Name of the branch to process
1009            series (Series): Series object
1010            max_vers (int): Version number of the series being updated
1011            new_name (str or None): New name, if a new branch is to be created
1012            dry_run (bool): True to do a dry run, restoring the original tree
1013                afterwards
1014            add_vers (int or None): Version number to add to the series, if any
1015            add_link (str or None): Link to add to the series, if any
1016            add_rtags (list of dict): List of review tags to add, one item for
1017                    each commit, each a dict:
1018                key: Response tag (e.g. 'Reviewed-by')
1019                value: Set of people who gave that response, each a name/email
1020                    string
1021            switch (bool): True to switch to the new branch after processing;
1022                otherwise HEAD remains at the original branch, as amended
1023
1024        Return:
1025            pygit.oid: oid of the new branch
1026        """
1027        def _do_version():
1028            if add_vers:
1029                if add_vers == 1:
1030                    vals.info += f'rm v{add_vers} '
1031                else:
1032                    vals.info += f'add v{add_vers} '
1033                    out.append(f'Series-version: {add_vers}')
1034
1035        def _do_links(new_links):
1036            if add_link:
1037                if 'add' not in vals.info:
1038                    vals.info += 'add '
1039                vals.info += f"links '{new_links}' "
1040            else:
1041                vals.info += f"upd links '{new_links}' "
1042            out.append(f'Series-links: {new_links}')
1043
1044        added_version = False
1045        added_link = False
1046        for vals in self.process_series(branch_name, series, new_name, switch,
1047                                         dry_run):
1048            out = []
1049            for line in vals.msg.splitlines():
1050                m_ver = re.match('Series-version:(.*)', line)
1051                m_links = re.match('Series-links:(.*)', line)
1052                if m_ver and add_vers:
1053                    if ('version' in series and
1054                            int(series.version) != max_vers):
1055                        tout.warning(
1056                            f'Branch {branch_name}: Series-version tag '
1057                            f'{series.version} does not match expected '
1058                            f'version {max_vers}')
1059                    _do_version()
1060                    added_version = True
1061                elif m_links:
1062                    links = series.get_links(m_links.group(1), max_vers)
1063                    if add_link:
1064                        links[max_vers] = add_link
1065                    _do_links(series.build_links(links))
1066                    added_link = True
1067                else:
1068                    out.append(line)
1069            if vals.final:
1070                if not added_version and add_vers and add_vers > 1:
1071                    _do_version()
1072                if not added_link and add_link:
1073                    _do_links(f'{max_vers}:{add_link}')
1074
1075            vals.msg = '\n'.join(out) + '\n'
1076            if add_rtags and add_rtags[vals.seq]:
1077                lines = []
1078                for tag, people in add_rtags[vals.seq].items():
1079                    for who in people:
1080                        lines.append(f'{tag}: {who}')
1081                vals.msg = patchstream.insert_tags(vals.msg.rstrip(),
1082                                                   sorted(lines))
1083                vals.info += (f'added {len(lines)} '
1084                              f"tag{'' if len(lines) == 1 else 's'}")
1085
1086    def _build_col(self, state, prefix='', base_str=None):
1087        """Build a patch-state string with colour
1088
1089        Args:
1090            state (str): State to colourise (also indicates the colour to use)
1091            prefix (str): Prefix string to also colourise
1092            base_str (str or None): String to show instead of state, or None to
1093                show state
1094
1095        Return:
1096            str: String with ANSI colour characters
1097        """
1098        bright = True
1099        if state == 'accepted':
1100            col = self.col.GREEN
1101        elif state == 'awaiting-upstream':
1102            bright = False
1103            col = self.col.GREEN
1104        elif state in ['changes-requested']:
1105            col = self.col.CYAN
1106        elif state in ['rejected', 'deferred', 'not-applicable', 'superseded',
1107                       'handled-elsewhere']:
1108            col = self.col.RED
1109        elif not state:
1110            state = 'unknown'
1111            col = self.col.MAGENTA
1112        else:
1113            # under-review, rfc, needs-review-ack
1114            col = self.col.WHITE
1115        out = base_str or SHORTEN_STATE.get(state, state)
1116        pad = ' ' * (10 - len(out))
1117        col_state = self.col.build(col, prefix + out, bright)
1118        return col_state, pad
1119
1120    def _get_patches(self, series, version):
1121        """Get a Series object containing the patches in a series
1122
1123        Args:
1124            series (str): Name of series to use, or None to use current branch
1125            version (int): Version number, or None to detect from name
1126
1127        Return: tuple:
1128            str: Name of branch, e.g. 'mary2'
1129            Series: Series object containing the commits and idnum, desc, name
1130            int: Version number of series, e.g. 2
1131            OrderedDict:
1132                key (int): record ID if find_svid is None, else seq
1133                value (PCOMMIT): record data
1134            str: series name (for this version)
1135            str: patchwork link
1136            str: cover_id
1137            int: cover_num_comments
1138        """
1139        ser, version = self._parse_series_and_version(series, version)
1140        if not ser.idnum:
1141            raise ValueError(f"Unknown series '{series}'")
1142        self._ensure_version(ser, version)
1143        svinfo = self.get_ser_ver(ser.idnum, version)
1144        pwc = self.get_pcommit_dict(svinfo.idnum)
1145
1146        count = len(pwc)
1147        branch = self._join_name_version(ser.name, version)
1148        series = patchstream.get_metadata(branch, 0, count,
1149                                          git_dir=self.gitdir)
1150        self._copy_db_fields_to(series, ser)
1151
1152        return (branch, series, version, pwc, svinfo.name, svinfo.link,
1153                svinfo.cover_id, svinfo.cover_num_comments)
1154
1155    def _list_patches(self, branch, pwc, series, desc, cover_id, num_comments,
1156                      show_commit, show_patch, list_patches, state_totals):
1157        """List patches along with optional status info
1158
1159        Args:
1160            branch (str): Branch name        if self.show_progress
1161            pwc (dict): pcommit records:
1162                key (int): seq
1163                value (PCOMMIT): Record from database
1164            series (Series): Series to show, or None to just use the database
1165            desc (str): Series title
1166            cover_id (int): Cover-letter ID
1167            num_comments (int): The number of comments on the cover letter
1168            show_commit (bool): True to show the commit and diffstate
1169            show_patch (bool): True to show the patch
1170            list_patches (bool): True to list all patches for each series,
1171                False to just show the series summary on a single line
1172            state_totals (dict): Holds totals for each state across all patches
1173                key (str): state name
1174                value (int): Number of patches in that state
1175
1176        Return:
1177            bool: True if OK, False if any commit subjects don't match their
1178                patchwork subjects
1179        """
1180        lines = []
1181        states = defaultdict(int)
1182        count = len(pwc)
1183        ok = True
1184        for seq, item in enumerate(pwc.values()):
1185            if series:
1186                cmt = series.commits[seq]
1187                if cmt.subject != item.subject:
1188                    ok = False
1189
1190            col_state, pad = self._build_col(item.state)
1191            patch_id = item.patch_id if item.patch_id else ''
1192            if item.num_comments:
1193                comments = str(item.num_comments)
1194            elif item.num_comments is None:
1195                comments = '-'
1196            else:
1197                comments = ''
1198
1199            if show_commit or show_patch:
1200                subject = self.col.build(self.col.BLACK, item.subject,
1201                                         bright=False, back=self.col.YELLOW)
1202            else:
1203                subject = item.subject
1204
1205            line = (f'{seq:3} {col_state}{pad} {comments.rjust(3)} '
1206                    f'{patch_id:7} {oid(cmt.hash)} {subject}')
1207            lines.append(line)
1208            states[item.state] += 1
1209        out = ''
1210        for state, freq in states.items():
1211            out += ' ' + self._build_col(state, f'{freq}:')[0]
1212            state_totals[state] += freq
1213        name = ''
1214        if not list_patches:
1215            name = desc or series.desc
1216            name = self.col.build(self.col.YELLOW, name[:41].ljust(41))
1217            if not ok:
1218                out = '*' + out[1:]
1219            print(f"{branch:16} {name} {len(pwc):5} {out}")
1220            return ok
1221        print(f"Branch '{branch}' (total {len(pwc)}):{out}{name}")
1222
1223        print(self.col.build(
1224            self.col.MAGENTA,
1225            f"Seq State      Com PatchId {'Commit'.ljust(HASH_LEN)} Subject"))
1226
1227        comments = '' if num_comments is None else str(num_comments)
1228        if desc or comments or cover_id:
1229            cov = 'Cov' if cover_id else ''
1230            print(self.col.build(
1231                self.col.WHITE,
1232                f"{cov:14} {comments.rjust(3)} {cover_id or '':7}            "
1233                f'{desc or series.desc}',
1234                bright=False))
1235        for seq in range(count):
1236            line = lines[seq]
1237            print(line)
1238            if show_commit or show_patch:
1239                print()
1240                cmt = series.commits[seq] if series else ''
1241                msg = gitutil.show_commit(
1242                    cmt.hash, show_commit, True, show_patch,
1243                    colour=self.col.enabled(), git_dir=self.gitdir)
1244                sys.stdout.write(msg)
1245                if seq != count - 1:
1246                    print()
1247                    print()
1248
1249        return ok
1250
1251    def _find_matched_commit(self, commits, pcm):
1252        """Find a commit in a list of possible matches
1253
1254        Args:
1255            commits (dict of Commit): Possible matches
1256                key (int): sequence number of patch (from 0)
1257                value (Commit): Commit object
1258            pcm (PCOMMIT): Patch to check
1259
1260        Return:
1261            int: Sequence number of matching commit, or None if not found
1262        """
1263        for seq, cmt in commits.items():
1264            tout.debug(f"- match subject: '{cmt.subject}'")
1265            if pcm.subject == cmt.subject:
1266                return seq
1267        return None
1268
1269    def _find_matched_patch(self, patches, cmt):
1270        """Find a patch in a list of possible matches
1271
1272        Args:
1273            patches: dict of ossible matches
1274                key (int): sequence number of patch
1275                value (PCOMMIT): patch
1276            cmt (Commit): Commit to check
1277
1278        Return:
1279            int: Sequence number of matching patch, or None if not found
1280        """
1281        for seq, pcm in patches.items():
1282            tout.debug(f"- match subject: '{pcm.subject}'")
1283            if cmt.subject == pcm.subject:
1284                return seq
1285        return None
1286
1287    def _sync_one(self, svid, series_name, version, show_comments,
1288                  show_cover_comments, gather_tags, cover, patches, dry_run):
1289        """Sync one series to the database
1290
1291        Args:
1292            svid (int): Ser/ver ID
1293            cover (dict or None): Cover letter from patchwork, with keys:
1294                id (int): Cover-letter ID in patchwork
1295                num_comments (int): Number of comments
1296                name (str): Cover-letter name
1297            patches (list of Patch): Patches in the series
1298        """
1299        pwc = self.get_pcommit_dict(svid)
1300        if gather_tags:
1301            count = len(pwc)
1302            branch = self._join_name_version(series_name, version)
1303            series = patchstream.get_metadata(branch, 0, count,
1304                                              git_dir=self.gitdir)
1305
1306            _, new_rtag_list = status.do_show_status(
1307                series, cover, patches, show_comments, show_cover_comments,
1308                self.col, warnings_on_stderr=False)
1309            self.update_series(branch, series, version, None, dry_run,
1310                                add_rtags=new_rtag_list)
1311
1312        updated = 0
1313        for seq, item in enumerate(pwc.values()):
1314            if seq >= len(patches):
1315                continue
1316            patch = patches[seq]
1317            if patch.id:
1318                if self.db.pcommit_update(
1319                    Pcommit(item.idnum, seq, None, None, None, patch.state,
1320                            patch.id, len(patch.comments))):
1321                    updated += 1
1322        if cover:
1323            info = SerVer(svid, None, None, None, cover.id,
1324                           cover.num_comments, cover.name, None)
1325        else:
1326            info = SerVer(svid, None, None, None, None, None, patches[0].name,
1327                           None)
1328        self.db.ser_ver_set_info(info)
1329
1330        return updated, 1 if cover else 0
1331
1332    async def _gather(self, pwork, link, show_cover_comments):
1333        """Sync the series status from patchwork
1334
1335        Creates a new client sesion and calls _sync()
1336
1337        Args:
1338            pwork (Patchwork): Patchwork object to use
1339            link (str): Patchwork link for the series
1340            show_cover_comments (bool): True to show the comments on the cover
1341                letter
1342
1343        Return: tuple:
1344            COVER object, or None if none or not read_cover_comments
1345            list of PATCH objects
1346        """
1347        async with aiohttp.ClientSession() as client:
1348            return await pwork.series_get_state(client, link, True,
1349                                                show_cover_comments)
1350
1351    def _get_fetch_dict(self, sync_all_versions):
1352        """Get a dict of ser_vers to fetch, along with their patchwork links
1353
1354        Args:
1355            sync_all_versions (bool): True to sync all versions of a series,
1356                False to sync only the latest version
1357
1358        Return: tuple:
1359            dict: things to fetch
1360                key (int): svid
1361                value (str): patchwork link for the series
1362            int: number of series which are missing a link
1363        """
1364        missing = 0
1365        svdict = self.get_ser_ver_dict()
1366        sdict = self.db.series_get_dict_by_id()
1367        to_fetch = {}
1368
1369        if sync_all_versions:
1370            for svinfo in self.get_ser_ver_list():
1371                ser_ver = svdict[svinfo.idnum]
1372                if svinfo.link:
1373                    to_fetch[svinfo.idnum] = patchwork.STATE_REQ(
1374                        svinfo.link, svinfo.series_id,
1375                        sdict[svinfo.series_id].name, svinfo.version, False,
1376                        False)
1377                else:
1378                    missing += 1
1379        else:
1380            # Find the maximum version for each series
1381            max_vers = self._series_all_max_versions()
1382
1383            # Get a list of links to fetch
1384            for svid, series_id, version in max_vers:
1385                ser_ver = svdict[svid]
1386                if series_id not in sdict:
1387                    # skip archived item
1388                    continue
1389                if ser_ver.link:
1390                    to_fetch[svid] = patchwork.STATE_REQ(
1391                        ser_ver.link, series_id, sdict[series_id].name,
1392                        version, False, False)
1393                else:
1394                    missing += 1
1395
1396        # order by series name, version
1397        ordered = OrderedDict()
1398        for svid in sorted(
1399                to_fetch,
1400                key=lambda k: (to_fetch[k].series_name, to_fetch[k].version)):
1401            sync = to_fetch[svid]
1402            ordered[svid] = sync
1403
1404        return ordered, missing
1405
1406    async def _sync_all(self, client, pwork, to_fetch):
1407        """Sync all series status from patchwork
1408
1409        Args:
1410            pwork (Patchwork): Patchwork object to use
1411            sync_all_versions (bool): True to sync all versions of a series,
1412                False to sync only the latest version
1413            gather_tags (bool): True to gather review/test tags
1414
1415        Return: list of tuple:
1416            COVER object, or None if none or not read_cover_comments
1417            list of PATCH objects
1418        """
1419        with pwork.collect_stats() as stats:
1420            tasks = [pwork.series_get_state(client, sync.link, True, True)
1421                     for sync in to_fetch.values() if sync.link]
1422            result = await asyncio.gather(*tasks)
1423        return result, stats.request_count
1424
1425    async def _do_series_sync_all(self, pwork, to_fetch):
1426        async with aiohttp.ClientSession() as client:
1427            return await self._sync_all(client, pwork, to_fetch)
1428
1429    def _progress_one(self, ser, show_all_versions, list_patches,
1430                      state_totals):
1431        """Show progress information for all versions in a series
1432
1433        Args:
1434            ser (Series): Series to use
1435            show_all_versions (bool): True to show all versions of a series,
1436                False to show only the final version
1437            list_patches (bool): True to list all patches for each series,
1438                False to just show the series summary on a single line
1439            state_totals (dict): Holds totals for each state across all patches
1440                key (str): state name
1441                value (int): Number of patches in that state
1442
1443        Return: tuple
1444            int: Number of series shown
1445            int: Number of patches shown
1446            int: Number of version which need a 'scan'
1447        """
1448        max_vers = self._series_max_version(ser.idnum)
1449        name, desc = self._get_series_info(ser.idnum)
1450        coloured = self.col.build(self.col.BLACK, desc, bright=False,
1451                                  back=self.col.YELLOW)
1452        versions = self._get_version_list(ser.idnum)
1453        vstr = list(map(str, versions))
1454
1455        if list_patches:
1456            print(f"{name}: {coloured} (versions: {' '.join(vstr)})")
1457        add_blank_line = False
1458        total_series = 0
1459        total_patches = 0
1460        need_scan = 0
1461        for ver in versions:
1462            if not show_all_versions and ver != max_vers:
1463                continue
1464            if add_blank_line:
1465                print()
1466            _, pwc = self._series_get_version_stats(ser.idnum, ver)
1467            count = len(pwc)
1468            branch = self._join_name_version(ser.name, ver)
1469            series = patchstream.get_metadata(branch, 0, count,
1470                                              git_dir=self.gitdir)
1471            svinfo = self.get_ser_ver(ser.idnum, ver)
1472            self._copy_db_fields_to(series, ser)
1473
1474            ok = self._list_patches(
1475                branch, pwc, series, svinfo.name, svinfo.cover_id,
1476                svinfo.cover_num_comments, False, False, list_patches,
1477                state_totals)
1478            if not ok:
1479                need_scan += 1
1480            add_blank_line = list_patches
1481            total_series += 1
1482            total_patches += count
1483        return total_series, total_patches, need_scan
1484
1485    def _summary_one(self, ser):
1486        """Show summary information for the latest version in a series
1487
1488        Args:
1489            series (str): Name of series to use, or None to show progress for
1490                all series
1491        """
1492        max_vers = self._series_max_version(ser.idnum)
1493        name, desc = self._get_series_info(ser.idnum)
1494        stats, pwc = self._series_get_version_stats(ser.idnum, max_vers)
1495        states = {x.state for x in pwc.values()}
1496        state = 'accepted'
1497        for val in ['awaiting-upstream', 'changes-requested', 'rejected',
1498                    'deferred', 'not-applicable', 'superseded',
1499                    'handled-elsewhere']:
1500            if val in states:
1501                state = val
1502        state_str, pad = self._build_col(state, base_str=name)
1503        print(f"{state_str}{pad}  {stats.rjust(6)}  {desc}")
1504
1505    def _series_max_version(self, idnum):
1506        """Find the latest version of a series
1507
1508        Args:
1509            idnum (int): Series ID to look up
1510
1511        Return:
1512            int: maximum version
1513        """
1514        return self.db.series_get_max_version(idnum)
1515
1516    def _series_all_max_versions(self):
1517        """Find the latest version of all series
1518
1519        Return: list of:
1520            int: ser_ver ID
1521            int: series ID
1522            int: Maximum version
1523        """
1524        return self.db.series_get_all_max_versions()
1525