1# SPDX-License-Identifier: GPL-2.0+
2#
3# Copyright 2025 Simon Glass <sjg@chromium.org>
4#
5"""Handles the patman database
6
7This uses sqlite3 with a local file.
8
To adjust the schema, increment LATEST, create a _migrate_to_v<x>() function
and write some code in migrate_to() to call it.
11"""
12
13from collections import namedtuple, OrderedDict
14import os
15import sqlite3
16
17from u_boot_pylib import tools
18from u_boot_pylib import tout
19from patman.series import Series
20
# Current schema version; version 0 means that no database exists yet
LATEST = 4

# Record from the ser_ver table, describing one version of a series:
# idnum (int): record ID
# series_id (int): ID of the owning record in the series table
# version (int): version number of the series
# link (str): patchwork link for this version, e.g. '12325', or None/'' if
#     not known
# cover_id (int): ID of the cover letter, if any (presumably the patchwork ID;
#     confirm against callers)
# cover_num_comments (int): number of comments on the cover letter
# name (str): name recorded for this series version
# archive_tag (str): archive tag recorded for this version, if any
SerVer = namedtuple('SER_VER', [
    'idnum', 'series_id', 'version', 'link', 'cover_id',
    'cover_num_comments', 'name', 'archive_tag',
])

# Record from the pcommit table, describing one patch in a series version:
# idnum (int): record ID
# seq (int): position of the patch within the series (0 is the first)
# subject (str): subject line of the patch
# svid (int): ID of the owning record in the ser_ver table
# change_id (str): value of the Change-Id tag
# state (str): current patchwork status of the patch
# patch_id (int): patchwork's ID for this patch
# num_comments (int): number of comments attached to the patch
Pcommit = namedtuple('PCOMMIT', [
    'idnum', 'seq', 'subject', 'svid', 'change_id', 'state', 'patch_id',
    'num_comments',
])
42
43
class Database:
    """Database of information used by patman

    Each object wraps a single sqlite3 connection to one database file. Only
    one object may exist per file path; use get_instance() to share it.
    """

    # dict of databases:
    #   key: filename
    #   value: Database object
    instances = {}

    def __init__(self, db_path):
        """Set up a new database object

        The database is not opened here; use start() or open_it() for that.

        Args:
            db_path (str): Path to the database

        Raises:
            ValueError: There is already a Database object for db_path
        """
        if db_path in Database.instances:
            # Two connections to the database can cause:
            # sqlite3.OperationalError: database is locked
            raise ValueError(f"There is already a database for '{db_path}'")
        self.con = None
        self.cur = None
        self.db_path = db_path
        self.is_open = False
        Database.instances[db_path] = self

    @staticmethod
    def get_instance(db_path):
        """Get the database instance for a path

        This is provided to ensure that different callers can obtain the
        same database object when accessing the same database file.

        Args:
            db_path (str): Path to the database

        Return:
            tuple:
                Database: Database instance, which is created if necessary
                bool: True if the instance was created by this call, False if
                    it already existed
        """
        db = Database.instances.get(db_path)
        if db is not None:
            return db, False
        return Database(db_path), True

    def start(self):
        """Open the database ready for use, migrating to the latest schema"""
        self.open_it()
        self.migrate_to(LATEST)

    def open_it(self):
        """Open the database, creating it if necessary

        Raises:
            ValueError: The database is already open
        """
        if self.is_open:
            raise ValueError('Already open')
        if not os.path.exists(self.db_path):
            tout.warning(f'Creating new database {self.db_path}')
        self.con = sqlite3.connect(self.db_path)
        self.cur = self.con.cursor()
        self.is_open = True

    def close(self):
        """Close the database

        Raises:
            ValueError: The database is already closed
        """
        if not self.is_open:
            raise ValueError('Already closed')
        self.con.close()
        self.cur = None
        self.con = None
        self.is_open = False

    def create_v1(self):
        """Create a database with the v1 schema"""
        self.cur.execute(
            'CREATE TABLE series (id INTEGER PRIMARY KEY AUTOINCREMENT,'
            'name UNIQUE, desc, archived BIT)')

        # Provides a series_id/version pair, which is used to refer to a
        # particular series version sent to patchwork. This stores the link
        # to patchwork
        self.cur.execute(
            'CREATE TABLE ser_ver (id INTEGER PRIMARY KEY AUTOINCREMENT,'
            'series_id INTEGER, version INTEGER, link,'
            'FOREIGN KEY (series_id) REFERENCES series (id))')

        self.cur.execute(
            'CREATE TABLE upstream (name UNIQUE, url, is_default BIT)')

        # change_id is the Change-Id
        # patch_id is the ID of the patch on the patchwork server
        self.cur.execute(
            'CREATE TABLE pcommit (id INTEGER PRIMARY KEY AUTOINCREMENT,'
            'svid INTEGER, seq INTEGER, subject, patch_id INTEGER, '
            'change_id, state, num_comments INTEGER, '
            'FOREIGN KEY (svid) REFERENCES ser_ver (id))')

        self.cur.execute(
            'CREATE TABLE settings (name UNIQUE, proj_id INT, link_name)')

    def _migrate_to_v2(self):
        """Add a schema_version table"""
        self.cur.execute('CREATE TABLE schema_version (version INTEGER)')

    def _migrate_to_v3(self):
        """Store the number of cover-letter comments in the schema"""
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN cover_id')
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN cover_num_comments '
                         'INTEGER')
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN name')

    def _migrate_to_v4(self):
        """Add an archive tag for each ser_ver"""
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN archive_tag')

    def migrate_to(self, dest_version):
        """Migrate the database to the selected version

        The migration happens one version at a time. Before each step the
        database file is copied to '<db_path>old.v<version>', where <version>
        is the version being migrated from; each step is then committed.

        Args:
            dest_version (int): Version to migrate to

        Raises:
            ValueError: The database schema is newer than dest_version, or
                there is no migration step for a required version
        """
        # Migration step to run to reach each schema version
        migrations = {
            1: self.create_v1,
            2: self._migrate_to_v2,
            3: self._migrate_to_v3,
            4: self._migrate_to_v4,
        }
        while True:
            version = self.get_schema_version()
            if version == dest_version:
                break
            if version > dest_version:
                # Versions only ever go up, so without this check the loop
                # would never reach dest_version
                raise ValueError(
                    f'Database schema v{version} is newer than v{dest_version}')
            step = migrations.get(version + 1)
            if step is None:
                # Refuse to record a schema version that was never applied
                raise ValueError(f'No migration available to v{version + 1}')

            self.close()
            tools.write_file(f'{self.db_path}old.v{version}',
                             tools.read_file(self.db_path))

            version += 1
            tout.info(f'Update database to v{version}')
            self.open_it()
            step()

            # Save the new version if we have a schema_version table
            if version > 1:
                self.cur.execute('DELETE FROM schema_version')
                self.cur.execute(
                    'INSERT INTO schema_version (version) VALUES (?)',
                    (version,))
            self.commit()

    def get_schema_version(self):
        """Get the version of the database's schema

        Return:
            int: Database version, 0 means there is no data; anything less than
                LATEST means the schema is out of date and must be updated
        """
        # If there is no series table, there is no database at all: v0
        try:
            self.cur.execute('SELECT name FROM series')
        except sqlite3.OperationalError:
            return 0

        # If there is no schema_version table, assume v1
        try:
            self.cur.execute('SELECT version FROM schema_version')
            version = self.cur.fetchone()[0]
        except sqlite3.OperationalError:
            return 1
        return version

    def execute(self, query, parameters=()):
        """Execute a database query

        Args:
            query (str): Query string, using '?' placeholders for values
            parameters (list of values): Parameters to pass

        Return:
            sqlite3.Cursor: Cursor holding the results of the query
        """
        return self.cur.execute(query, parameters)

    def commit(self):
        """Commit changes to the database"""
        self.con.commit()

    def rollback(self):
        """Roll back changes to the database"""
        self.con.rollback()

    def lastrowid(self):
        """Get the last row-ID reported by the database

        Return:
            int: Value for lastrowid
        """
        return self.cur.lastrowid

    def rowcount(self):
        """Get the row-count reported by the database

        Return:
            int: Value for rowcount
        """
        return self.cur.rowcount

    def _get_series_list(self, include_archived):
        """Get a list of Series objects from the database

        Args:
            include_archived (bool): True to include archived series

        Return:
            list of Series
        """
        res = self.execute(
            'SELECT id, name, desc FROM series ' +
            ('WHERE archived = 0' if not include_archived else ''))
        return [Series.from_fields(idnum=idnum, name=name, desc=desc)
                for idnum, name, desc in res.fetchall()]

    # series functions

    def series_get_dict_by_id(self, include_archived=False):
        """Get a dict of Series objects from the database

        Args:
            include_archived (bool): True to include archived series

        Return:
            OrderedDict:
                key: series ID
                value: Series with idnum, name and desc filled out
        """
        sdict = OrderedDict()
        for ser in self._get_series_list(include_archived):
            sdict[ser.idnum] = ser
        return sdict

    def series_find_by_name(self, name, include_archived=False):
        """Find a series and return its details

        Args:
            name (str): Name to search for
            include_archived (bool): True to include archived series

        Returns:
            idnum, or None if not found
        """
        res = self.execute(
            'SELECT id FROM series WHERE name = ?' +
            (' AND archived = 0' if not include_archived else ''), (name,))
        recs = res.fetchall()

        # This shouldn't happen, since the name column is UNIQUE
        assert len(recs) <= 1, 'Expected one match, but multiple found'

        if len(recs) != 1:
            return None
        return recs[0][0]

    def series_get_info(self, idnum):
        """Get information for a series from the database

        Args:
            idnum (int): Series ID to look up

        Return: tuple:
            str: Series name
            str: Series description

        Raises:
            ValueError: Series is not found
        """
        res = self.execute('SELECT name, desc FROM series WHERE id = ?',
                           (idnum,))
        recs = res.fetchall()
        if len(recs) != 1:
            raise ValueError(f'No series found (id {idnum} len {len(recs)})')
        return recs[0]

    def series_get_dict(self, include_archived=False):
        """Get a dict of Series objects from the database

        Args:
            include_archived (bool): True to include archived series

        Return:
            OrderedDict:
                key: series name
                value: Series with idnum, name and desc filled out
        """
        sdict = OrderedDict()
        for ser in self._get_series_list(include_archived):
            sdict[ser.name] = ser
        return sdict

    def series_get_version_list(self, series_idnum):
        """Get a list of the versions available for a series

        Args:
            series_idnum (int): ID of series to look up

        Return:
            list of int: List of versions, which may be empty if the series is
                in the process of being added
        """
        res = self.execute('SELECT version FROM ser_ver WHERE series_id = ?',
                           (series_idnum,))
        return [x[0] for x in res.fetchall()]

    def series_get_max_version(self, series_idnum):
        """Get the highest version number available for a series

        Args:
            series_idnum (int): ID of series to look up

        Return:
            int: Maximum version number, or None if the series has no versions
        """
        res = self.execute(
            'SELECT MAX(version) FROM ser_ver WHERE series_id = ?',
            (series_idnum,))
        return res.fetchall()[0][0]

    def series_get_all_max_versions(self):
        """Find the latest version of all series

        Return: list of:
            int: ser_ver ID
            int: series ID
            int: Maximum version
        """
        res = self.execute(
            'SELECT id, series_id, MAX(version) FROM ser_ver '
            'GROUP BY series_id')
        return res.fetchall()

    def series_add(self, name, desc):
        """Add a new series record

        The new record is set to not archived

        Args:
            name (str): Series name
            desc (str): Series description

        Return:
            int: ID num of the new series record
        """
        # Use placeholders so that quotes in name/desc are stored correctly
        self.execute(
            'INSERT INTO series (name, desc, archived) VALUES (?, ?, 0)',
            (name, desc))
        return self.lastrowid()

    def series_remove(self, idnum):
        """Remove a series from the database

        The series must exist

        Args:
            idnum (int): ID num of series to remove
        """
        self.execute('DELETE FROM series WHERE id = ?', (idnum,))
        assert self.rowcount() == 1

    def series_remove_by_name(self, name):
        """Remove a series from the database

        Args:
            name (str): Name of series to remove

        Raises:
            ValueError: Series does not exist (database is rolled back)
        """
        self.execute('DELETE FROM series WHERE name = ?', (name,))
        if self.rowcount() != 1:
            self.rollback()
            raise ValueError(f"No such series '{name}'")

    def series_set_archived(self, series_idnum, archived):
        """Update archive flag for a series

        Args:
            series_idnum (int): ID num of the series
            archived (bool): Whether to mark the series as archived or
                unarchived
        """
        self.execute(
            'UPDATE series SET archived = ? WHERE id = ?',
            (archived, series_idnum))

    def series_set_name(self, series_idnum, name):
        """Update name for a series

        Args:
            series_idnum (int): ID num of the series
            name (str): new name to use
        """
        self.execute(
            'UPDATE series SET name = ? WHERE id = ?', (name, series_idnum))

    # ser_ver functions

    def ser_ver_get_link(self, series_idnum, version):
        """Get the link for a series version

        Args:
            series_idnum (int): ID num of the series
            version (int): Version number to search for

        Return:
            str: Patchwork link as a string, e.g. '12325', or None if none

        Raises:
            ValueError: Multiple matches are found
        """
        res = self.execute(
            'SELECT link FROM ser_ver WHERE series_id = ? AND version = ?',
            (series_idnum, version))
        recs = res.fetchall()
        if not recs:
            return None
        if len(recs) > 1:
            raise ValueError('Expected one match, but multiple matches found')
        return recs[0][0]

    def ser_ver_set_link(self, series_idnum, version, link):
        """Set the link for a series version

        Args:
            series_idnum (int): ID num of the series
            version (int): Version number to search for
            link (str): Patchwork link for the ser_ver, or None to clear it
                (stored as an empty string)

        Return:
            bool: True if the record was found and updated, else False
        """
        if link is None:
            link = ''
        self.execute(
            'UPDATE ser_ver SET link = ? WHERE series_id = ? AND version = ?',
            (str(link), series_idnum, version))
        return self.rowcount() != 0

    def ser_ver_set_info(self, info):
        """Set the info for a series version

        Args:
            info (SER_VER): Info to set. Only two options are supported:
                1: idnum,cover_id,cover_num_comments,name
                2: idnum,name

        Return:
            bool: True if the record was found and updated, else False
        """
        assert info.idnum is not None
        if info.cover_id:
            assert info.series_id is None
            self.execute(
                'UPDATE ser_ver SET cover_id = ?, cover_num_comments = ?, '
                'name = ? WHERE id = ?',
                (info.cover_id, info.cover_num_comments, info.name,
                 info.idnum))
        else:
            # Only the name may be set when there is no cover letter
            assert not info.cover_num_comments
            assert not info.series_id
            assert not info.version
            assert not info.link
            self.execute('UPDATE ser_ver SET name = ? WHERE id = ?',
                         (info.name, info.idnum))

        return self.rowcount() != 0

    def ser_ver_set_version(self, svid, version):
        """Sets the version for a ser_ver record

        Args:
            svid (int): Record ID to update
            version (int): Version number to add

        Raises:
            ValueError: svid was not found
        """
        self.execute(
            'UPDATE ser_ver SET version = ? WHERE id = ?', (version, svid))
        if self.rowcount() != 1:
            raise ValueError(f'No ser_ver updated (svid {svid})')

    def ser_ver_set_archive_tag(self, svid, tag):
        """Sets the archive tag for a ser_ver record

        Args:
            svid (int): Record ID to update
            tag (str): Tag to add

        Raises:
            ValueError: svid was not found
        """
        self.execute(
            'UPDATE ser_ver SET archive_tag = ? WHERE id = ?', (tag, svid))
        if self.rowcount() != 1:
            raise ValueError(f'No ser_ver updated (svid {svid})')

    def ser_ver_add(self, series_idnum, version, link=None):
        """Add a new ser_ver record

        Args:
            series_idnum (int): ID num of the series which is getting a new
                version
            version (int): Version number to add
            link (str): Patchwork link, or None if not known

        Return:
            int: ID num of the new ser_ver record
        """
        self.execute(
            'INSERT INTO ser_ver (series_id, version, link) VALUES (?, ?, ?)',
            (series_idnum, version, link))
        return self.lastrowid()

    def ser_ver_get_for_series(self, series_idnum, version=None):
        """Get ser_ver records for a given series ID

        Args:
            series_idnum (int): ID num of the series to search
            version (int): Version number to search for, or None for all

        Return:
            SER_VER if version is given, else list of SER_VER for all versions

        Raises:
            ValueError: There is no matching idnum/version
        """
        base = ('SELECT id, series_id, version, link, cover_id, '
                'cover_num_comments, name, archive_tag FROM ser_ver '
                'WHERE series_id = ?')
        if version:
            res = self.execute(base + ' AND version = ?',
                               (series_idnum, version))
        else:
            res = self.execute(base, (series_idnum,))
        recs = res.fetchall()
        if not recs:
            raise ValueError(
                f'No matching series for id {series_idnum} version {version}')
        if version:
            return SerVer(*recs[0])
        return [SerVer(*x) for x in recs]

    def ser_ver_get_ids_for_series(self, series_idnum, version=None):
        """Get the IDs of the ser_ver records for a given series ID

        Args:
            series_idnum (int): ID num of the series to search
            version (int): Version number to search for, or None for all

        Return:
            list of int: List of svids for the matching records (empty if
                there are none)
        """
        if version:
            res = self.execute(
                'SELECT id FROM ser_ver WHERE series_id = ? AND version = ?',
                (series_idnum, version))
        else:
            res = self.execute(
                'SELECT id FROM ser_ver WHERE series_id = ?', (series_idnum,))
        # Each row is a 1-tuple; collect the ID from every matching row
        return [rec[0] for rec in res.fetchall()]

    def ser_ver_get_list(self):
        """Get a list of patchwork entries from the database

        Return:
            list of SER_VER
        """
        res = self.execute(
            'SELECT id, series_id, version, link, cover_id, '
            'cover_num_comments, name, archive_tag FROM ser_ver')
        items = res.fetchall()
        return [SerVer(*x) for x in items]

    def ser_ver_remove(self, series_idnum, version=None, remove_pcommits=True,
                       remove_series=True):
        """Delete a ser_ver record

        Removes the record which has the given series ID num and version

        Args:
            series_idnum (int): ID num of the series
            version (int): Version number, or None to remove all versions
            remove_pcommits (bool): True to remove associated pcommits too
            remove_series (bool): True to remove the series if version is None
        """
        if remove_pcommits:
            # Figure out svids to delete
            svids = self.ser_ver_get_ids_for_series(series_idnum, version)

            self.pcommit_delete_list(svids)

        if version:
            self.execute(
                'DELETE FROM ser_ver WHERE series_id = ? AND version = ?',
                (series_idnum, version))
        else:
            self.execute(
                'DELETE FROM ser_ver WHERE series_id = ?',
                (series_idnum,))
        if not version and remove_series:
            self.series_remove(series_idnum)

    # pcommit functions

    def pcommit_get_list(self, find_svid=None):
        """Get a list of pcommit entries from the database

        Args:
            find_svid (int): If not None, finds the records associated with a
                particular series and version; otherwise returns all records

        Return:
            list of PCOMMIT: pcommit records
        """
        query = ('SELECT id, seq, subject, svid, change_id, state, patch_id, '
                 'num_comments FROM pcommit')
        params = ()
        if find_svid is not None:
            query += ' WHERE svid = ?'
            params = (find_svid,)
        res = self.execute(query, params)
        return [Pcommit(*rec) for rec in res.fetchall()]

    def pcommit_add_list(self, svid, pcommits):
        """Add records to the pcommit table

        Args:
            svid (int): ser_ver ID num
            pcommits (list of PCOMMIT): Only seq, subject, change_id are
                used; svid comes from the argument passed in and the others
                are assumed to be obtained from patchwork later
        """
        for pcm in pcommits:
            self.execute(
                'INSERT INTO pcommit (svid, seq, subject, change_id) VALUES '
                '(?, ?, ?, ?)', (svid, pcm.seq, pcm.subject, pcm.change_id))

    def pcommit_delete(self, svid):
        """Delete pcommit records for a given ser_ver ID

        Args:
            svid (int): ser_ver ID num of records to delete
        """
        self.execute('DELETE FROM pcommit WHERE svid = ?', (svid,))

    def pcommit_delete_list(self, svid_list):
        """Delete pcommit records for a given set of ser_ver IDs

        Args:
            svid_list (list of int): ser_ver ID nums of records to delete;
                nothing is deleted if this is empty
        """
        svids = list(svid_list)
        if not svids:
            return
        # One placeholder per ID: binding a single joined string would only
        # ever match when there is exactly one ID
        placeholders = ', '.join('?' * len(svids))
        self.execute(f'DELETE FROM pcommit WHERE svid IN ({placeholders})',
                     svids)

    def pcommit_update(self, pcm):
        """Update a pcommit record

        Args:
            pcm (PCOMMIT): Information to write; only the idnum, state,
                patch_id and num_comments are used

        Return:
            True if the data was written
        """
        self.execute(
            'UPDATE pcommit SET '
            'patch_id = ?, state = ?, num_comments = ? WHERE id = ?',
            (pcm.patch_id, pcm.state, pcm.num_comments, pcm.idnum))
        return self.rowcount() > 0

    # upstream functions

    def upstream_add(self, name, url):
        """Add a new upstream record

        Args:
            name (str): Name of the tree
            url (str): URL for the tree

        Raises:
            ValueError: The name already exists in the database
            sqlite3.IntegrityError: Any other integrity failure
        """
        try:
            self.execute(
                'INSERT INTO upstream (name, url) VALUES (?, ?)', (name, url))
        except sqlite3.IntegrityError as exc:
            if 'UNIQUE constraint failed: upstream.name' in str(exc):
                raise ValueError(f"Upstream '{name}' already exists") from exc
            # Anything else is unexpected, so don't hide it
            raise

    def upstream_set_default(self, name):
        """Mark (only) the given upstream as the default

        Args:
            name (str): Name of the upstream remote to set as default, or None
                to clear the default

        Raises:
            ValueError: The name does not match exactly one upstream; the
                database is rolled back
        """
        self.execute("UPDATE upstream SET is_default = 0")
        if name is not None:
            self.execute(
                'UPDATE upstream SET is_default = 1 WHERE name = ?', (name,))
            if self.rowcount() != 1:
                self.rollback()
                raise ValueError(f"No such upstream '{name}'")

    def upstream_get_default(self):
        """Get the name of the default upstream

        Return:
            str: Default-upstream name, or None if there is no default
        """
        res = self.execute(
            "SELECT name FROM upstream WHERE is_default = 1")
        recs = res.fetchall()
        if len(recs) != 1:
            return None
        return recs[0][0]

    def upstream_delete(self, name):
        """Delete an upstream target

        Args:
            name (str): Name of the upstream remote to delete

        Raises:
            ValueError: Upstream does not exist (database is rolled back)
        """
        self.execute('DELETE FROM upstream WHERE name = ?', (name,))
        if self.rowcount() != 1:
            self.rollback()
            raise ValueError(f"No such upstream '{name}'")

    def upstream_get_dict(self):
        """Get a dict of upstream entries from the database

        Return:
            OrderedDict:
                key (str): upstream name
                value (tuple): (url, is_default)
        """
        res = self.execute('SELECT name, url, is_default FROM upstream')
        udict = OrderedDict()
        for name, url, is_default in res.fetchall():
            udict[name] = url, is_default
        return udict

    # settings functions

    def settings_update(self, name, proj_id, link_name):
        """Set the patchwork settings of the project

        Any existing settings are replaced.

        Args:
            name (str): Name of the project to use in patchwork
            proj_id (int): Project ID for the project
            link_name (str): Link name for the project
        """
        self.execute('DELETE FROM settings')
        self.execute(
            'INSERT INTO settings (name, proj_id, link_name) '
            'VALUES (?, ?, ?)', (name, proj_id, link_name))

    def settings_get(self):
        """Get the patchwork settings of the project

        Returns:
            tuple or None if there are no settings:
                name (str): Project name, e.g. 'U-Boot'
                proj_id (int): Patchworks project ID for this project
                link_name (str): Patchwork's link-name for the project
        """
        res = self.execute("SELECT name, proj_id, link_name FROM settings")
        recs = res.fetchall()
        if len(recs) != 1:
            return None
        return recs[0]
824