# SPDX-License-Identifier: GPL-2.0+
#
# Copyright 2025 Simon Glass <sjg@chromium.org>
#
"""Handles the patman database

This uses sqlite3 with a local file.

To adjust the schema, increment LATEST, create a _migrate_to_v<x>() method
and add it to the table of steps in migrate_to().
"""

from collections import namedtuple, OrderedDict
import os
import sqlite3

from u_boot_pylib import tools
from u_boot_pylib import tout
from patman.series import Series

# Schema version (version 0 means there is no database yet)
LATEST = 4

# Information about a series/version record
SerVer = namedtuple(
    'SER_VER',
    'idnum,series_id,version,link,cover_id,cover_num_comments,name,'
    'archive_tag')

# Record from the pcommit table:
# idnum (int): record ID
# seq (int): Patch sequence in series (0 is first)
# subject (str): patch subject
# svid (int): ID of series/version record in ser_ver table
# change_id (str): Change-ID value
# state (str): Current status in patchwork
# patch_id (int): Patchwork's patch ID for this patch
# num_comments (int): Number of comments attached to the commit
Pcommit = namedtuple(
    'PCOMMIT',
    'idnum,seq,subject,svid,change_id,state,patch_id,num_comments')


class Database:
    """Database of information used by patman

    All values are passed to sqlite3 as bound parameters rather than being
    formatted into the query text, so names and descriptions may safely
    contain quotes or other SQL syntax.
    """

    # dict of databases:
    #   key: filename
    #   value: Database object
    instances = {}

    def __init__(self, db_path):
        """Set up a new database object

        The object is registered in Database.instances; the database itself
        is not opened until open_it() or start() is called.

        Args:
            db_path (str): Path to the database

        Raises:
            ValueError: A Database object already exists for this path
        """
        if db_path in Database.instances:
            # Two connections to the database can cause:
            # sqlite3.OperationalError: database is locked
            raise ValueError(f"There is already a database for '{db_path}'")
        self.con = None
        self.cur = None
        self.db_path = db_path
        self.is_open = False
        Database.instances[db_path] = self

    @staticmethod
    def get_instance(db_path):
        """Get the database instance for a path

        This is provided to ensure that different callers can obtain the
        same database object when accessing the same database file.

        Args:
            db_path (str): Path to the database

        Return: tuple:
            Database: Database instance, which is created if necessary
            bool: True if a new instance was created, False if an existing
                one was returned
        """
        db = Database.instances.get(db_path)
        if db is not None:
            return db, False
        return Database(db_path), True

    def start(self):
        """Open the database ready for use, migrate to latest schema"""
        self.open_it()
        self.migrate_to(LATEST)

    def open_it(self):
        """Open the database, creating it if necessary

        Raises:
            ValueError: The database is already open
        """
        if self.is_open:
            raise ValueError('Already open')
        if not os.path.exists(self.db_path):
            tout.warning(f'Creating new database {self.db_path}')
        self.con = sqlite3.connect(self.db_path)
        self.cur = self.con.cursor()
        self.is_open = True

    def close(self):
        """Close the database

        Raises:
            ValueError: The database is not open
        """
        if not self.is_open:
            raise ValueError('Already closed')
        self.con.close()
        self.cur = None
        self.con = None
        self.is_open = False

    def create_v1(self):
        """Create a database with the v1 schema"""
        self.cur.execute(
            'CREATE TABLE series (id INTEGER PRIMARY KEY AUTOINCREMENT,'
            'name UNIQUE, desc, archived BIT)')

        # Provides a series_id/version pair, which is used to refer to a
        # particular series version sent to patchwork. This stores the link
        # to patchwork
        self.cur.execute(
            'CREATE TABLE ser_ver (id INTEGER PRIMARY KEY AUTOINCREMENT,'
            'series_id INTEGER, version INTEGER, link,'
            'FOREIGN KEY (series_id) REFERENCES series (id))')

        self.cur.execute(
            'CREATE TABLE upstream (name UNIQUE, url, is_default BIT)')

        # change_id is the Change-Id
        # patch_id is the ID of the patch on the patchwork server
        self.cur.execute(
            'CREATE TABLE pcommit (id INTEGER PRIMARY KEY AUTOINCREMENT,'
            'svid INTEGER, seq INTEGER, subject, patch_id INTEGER, '
            'change_id, state, num_comments INTEGER, '
            'FOREIGN KEY (svid) REFERENCES ser_ver (id))')

        self.cur.execute(
            'CREATE TABLE settings (name UNIQUE, proj_id INT, link_name)')

    def _migrate_to_v2(self):
        """Add a schema_version table"""
        self.cur.execute('CREATE TABLE schema_version (version INTEGER)')

    def _migrate_to_v3(self):
        """Store the number of cover-letter comments in the schema"""
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN cover_id')
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN cover_num_comments '
                         'INTEGER')
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN name')

    def _migrate_to_v4(self):
        """Add an archive tag for each ser_ver"""
        self.cur.execute('ALTER TABLE ser_ver ADD COLUMN archive_tag')

    def migrate_to(self, dest_version):
        """Migrate the database to the selected version

        The schema is advanced one version at a time. Before each step the
        database file is copied to '<db_path>old.v<version>' as a backup.

        Args:
            dest_version (int): Version to migrate to

        Raises:
            ValueError: The database is already newer than dest_version, or
                there is no migration step to reach the next version
        """
        # Step to run in order to arrive at each schema version
        steps = {
            1: self.create_v1,
            2: self._migrate_to_v2,
            3: self._migrate_to_v3,
            4: self._migrate_to_v4,
        }
        while True:
            version = self.get_schema_version()
            if version == dest_version:
                break

            # Check these before touching the file. Counting upwards from a
            # newer schema would never reach dest_version, and stamping a
            # version with no migration step would leave the schema wrong
            if version > dest_version:
                raise ValueError(
                    f'Cannot migrate database from v{version} down to '
                    f'v{dest_version}')
            step = steps.get(version + 1)
            if step is None:
                raise ValueError(
                    f'No migration available to database v{version + 1}')

            # Back up the file while it is closed
            self.close()
            tools.write_file(f'{self.db_path}old.v{version}',
                             tools.read_file(self.db_path))

            version += 1
            tout.info(f'Update database to v{version}')
            self.open_it()
            step()

            # Save the new version if we have a schema_version table
            if version > 1:
                self.cur.execute('DELETE FROM schema_version')
                self.cur.execute(
                    'INSERT INTO schema_version (version) VALUES (?)',
                    (version,))
            self.commit()

    def get_schema_version(self):
        """Get the version of the database's schema

        Return:
            int: Database version, 0 means there is no data; anything less than
                LATEST means the schema is out of date and must be updated
        """
        # If there is no database at all, assume v0
        try:
            self.cur.execute('SELECT name FROM series')
        except sqlite3.OperationalError:
            return 0

        # If there is no schema_version table, assume v1
        try:
            self.cur.execute('SELECT version FROM schema_version')
            version = self.cur.fetchone()[0]
        except sqlite3.OperationalError:
            return 1
        return version

    def execute(self, query, parameters=()):
        """Execute a database query

        Args:
            query (str): Query string, using '?' placeholders for values
            parameters (list of values): Parameters to pass

        Return:
            sqlite3.Cursor: Cursor from which results can be fetched
        """
        return self.cur.execute(query, parameters)

    def commit(self):
        """Commit changes to the database"""
        self.con.commit()

    def rollback(self):
        """Roll back changes to the database"""
        self.con.rollback()

    def lastrowid(self):
        """Get the last row-ID reported by the database

        Return:
            int: Value for lastrowid
        """
        return self.cur.lastrowid

    def rowcount(self):
        """Get the row-count reported by the database

        Return:
            int: Value for rowcount
        """
        return self.cur.rowcount

    def _get_series_list(self, include_archived):
        """Get a list of Series objects from the database

        Args:
            include_archived (bool): True to include archived series

        Return:
            list of Series
        """
        res = self.execute(
            'SELECT id, name, desc FROM series ' +
            ('WHERE archived = 0' if not include_archived else ''))
        return [Series.from_fields(idnum=idnum, name=name, desc=desc)
                for idnum, name, desc in res.fetchall()]

    # series functions

    def series_get_dict_by_id(self, include_archived=False):
        """Get a dict of Series objects from the database

        Args:
            include_archived (bool): True to include archived series

        Return:
            OrderedDict:
                key: series ID
                value: Series with idnum, name and desc filled out
        """
        sdict = OrderedDict()
        for ser in self._get_series_list(include_archived):
            sdict[ser.idnum] = ser
        return sdict

    def series_find_by_name(self, name, include_archived=False):
        """Find a series and return its details

        Args:
            name (str): Name to search for
            include_archived (bool): True to include archived series

        Returns:
            idnum, or None if not found
        """
        res = self.execute(
            'SELECT id FROM series WHERE name = ?' +
            (' AND archived = 0' if not include_archived else ''), (name,))
        recs = res.fetchall()

        # This shouldn't happen
        assert len(recs) <= 1, 'Expected one match, but multiple found'

        if len(recs) != 1:
            return None
        return recs[0][0]

    def series_get_info(self, idnum):
        """Get information for a series from the database

        Args:
            idnum (int): Series ID to look up

        Return: tuple:
            str: Series name
            str: Series description

        Raises:
            ValueError: Series is not found
        """
        res = self.execute('SELECT name, desc FROM series WHERE id = ?',
                           (idnum,))
        recs = res.fetchall()
        if len(recs) != 1:
            raise ValueError(f'No series found (id {idnum} len {len(recs)})')
        return recs[0]

    def series_get_dict(self, include_archived=False):
        """Get a dict of Series objects from the database

        Args:
            include_archived (bool): True to include archived series

        Return:
            OrderedDict:
                key: series name
                value: Series with idnum, name and desc filled out
        """
        sdict = OrderedDict()
        for ser in self._get_series_list(include_archived):
            sdict[ser.name] = ser
        return sdict

    def series_get_version_list(self, series_idnum):
        """Get a list of the versions available for a series

        Args:
            series_idnum (int): ID of series to look up

        Return:
            list of int: List of versions, which may be empty if the series
                is in the process of being added
        """
        res = self.execute('SELECT version FROM ser_ver WHERE series_id = ?',
                           (series_idnum,))
        return [x[0] for x in res.fetchall()]

    def series_get_max_version(self, series_idnum):
        """Get the highest version number available for a series

        Args:
            series_idnum (int): ID of series to look up

        Return:
            int: Maximum version number, or None if the series has no
                versions
        """
        res = self.execute(
            'SELECT MAX(version) FROM ser_ver WHERE series_id = ?',
            (series_idnum,))
        return res.fetchall()[0][0]

    def series_get_all_max_versions(self):
        """Find the latest version of all series

        Return: list of:
            int: ser_ver ID
            int: series ID
            int: Maximum version
        """
        res = self.execute(
            'SELECT id, series_id, MAX(version) FROM ser_ver '
            'GROUP BY series_id')
        return res.fetchall()

    def series_add(self, name, desc):
        """Add a new series record

        The new record is set to not archived

        Args:
            name (str): Series name
            desc (str): Series description

        Return:
            int: ID num of the new series record
        """
        # Bind the values so that quotes in the name or description cannot
        # break (or alter) the statement
        self.execute(
            'INSERT INTO series (name, desc, archived) VALUES (?, ?, 0)',
            (name, desc))
        return self.lastrowid()

    def series_remove(self, idnum):
        """Remove a series from the database

        The series must exist

        Args:
            idnum (int): ID num of series to remove
        """
        self.execute('DELETE FROM series WHERE id = ?', (idnum,))
        assert self.rowcount() == 1

    def series_remove_by_name(self, name):
        """Remove a series from the database

        Args:
            name (str): Name of series to remove

        Raises:
            ValueError: Series does not exist (database is rolled back)
        """
        self.execute('DELETE FROM series WHERE name = ?', (name,))
        if self.rowcount() != 1:
            self.rollback()
            raise ValueError(f"No such series '{name}'")

    def series_set_archived(self, series_idnum, archived):
        """Update archive flag for a series

        Args:
            series_idnum (int): ID num of the series
            archived (bool): Whether to mark the series as archived or
                unarchived
        """
        self.execute(
            'UPDATE series SET archived = ? WHERE id = ?',
            (archived, series_idnum))

    def series_set_name(self, series_idnum, name):
        """Update name for a series

        Args:
            series_idnum (int): ID num of the series
            name (str): new name to use
        """
        self.execute(
            'UPDATE series SET name = ? WHERE id = ?', (name, series_idnum))

    # ser_ver functions

    def ser_ver_get_link(self, series_idnum, version):
        """Get the link for a series version

        Args:
            series_idnum (int): ID num of the series
            version (int): Version number to search for

        Return:
            str: Patchwork link as a string, e.g. '12325', or None if none

        Raises:
            ValueError: Multiple matches are found
        """
        res = self.execute(
            'SELECT link FROM ser_ver WHERE series_id = ? AND version = ?',
            (series_idnum, version))
        recs = res.fetchall()
        if not recs:
            return None
        if len(recs) > 1:
            raise ValueError('Expected one match, but multiple matches found')
        return recs[0][0]

    def ser_ver_set_link(self, series_idnum, version, link):
        """Set the link for a series version

        Args:
            series_idnum (int): ID num of the series
            version (int): Version number to search for
            link (str): Patchwork link for the ser_ver; None is stored as an
                empty string

        Return:
            bool: True if the record was found and updated, else False
        """
        if link is None:
            link = ''
        self.execute(
            'UPDATE ser_ver SET link = ? WHERE series_id = ? AND version = ?',
            (str(link), series_idnum, version))
        return self.rowcount() != 0

    def ser_ver_set_info(self, info):
        """Set the info for a series version

        Args:
            info (SER_VER): Info to set. Only two options are supported:
                1: svid,cover_id,cover_num_comments,name
                2: svid,name

        Return:
            bool: True if the record was found and updated, else False
        """
        assert info.idnum is not None
        if info.cover_id:
            assert info.series_id is None
            self.execute(
                'UPDATE ser_ver SET cover_id = ?, cover_num_comments = ?, '
                'name = ? WHERE id = ?',
                (info.cover_id, info.cover_num_comments, info.name,
                 info.idnum))
        else:
            assert not info.cover_id
            assert not info.cover_num_comments
            assert not info.series_id
            assert not info.version
            assert not info.link
            self.execute('UPDATE ser_ver SET name = ? WHERE id = ?',
                         (info.name, info.idnum))

        return self.rowcount() != 0

    def ser_ver_set_version(self, svid, version):
        """Sets the version for a ser_ver record

        Args:
            svid (int): Record ID to update
            version (int): Version number to add

        Raises:
            ValueError: svid was not found
        """
        self.execute(
            'UPDATE ser_ver SET version = ? WHERE id = ?', (version, svid))
        if self.rowcount() != 1:
            raise ValueError(f'No ser_ver updated (svid {svid})')

    def ser_ver_set_archive_tag(self, svid, tag):
        """Sets the archive tag for a ser_ver record

        Args:
            svid (int): Record ID to update
            tag (str): Tag to add

        Raises:
            ValueError: svid was not found
        """
        self.execute(
            'UPDATE ser_ver SET archive_tag = ? WHERE id = ?', (tag, svid))
        if self.rowcount() != 1:
            raise ValueError(f'No ser_ver updated (svid {svid})')

    def ser_ver_add(self, series_idnum, version, link=None):
        """Add a new ser_ver record

        Args:
            series_idnum (int): ID num of the series which is getting a new
                version
            version (int): Version number to add
            link (str): Patchwork link, or None if not known

        Return:
            int: ID num of the new ser_ver record
        """
        self.execute(
            'INSERT INTO ser_ver (series_id, version, link) VALUES (?, ?, ?)',
            (series_idnum, version, link))
        return self.lastrowid()

    def ser_ver_get_for_series(self, series_idnum, version=None):
        """Get ser_ver records for a given series ID

        Args:
            series_idnum (int): ID num of the series to search
            version (int): Version number to search for, or None (or 0) for
                all

        Return:
            SER_VER: Requested record, if a version was given
            list of SER_VER: All records for the series, if not

        Raises:
            ValueError: There is no matching idnum/version
        """
        base = ('SELECT id, series_id, version, link, cover_id, '
                'cover_num_comments, name, archive_tag FROM ser_ver '
                'WHERE series_id = ?')
        if version:
            res = self.execute(base + ' AND version = ?',
                               (series_idnum, version))
        else:
            res = self.execute(base, (series_idnum,))
        recs = res.fetchall()
        if not recs:
            raise ValueError(
                f'No matching series for id {series_idnum} version {version}')
        if version:
            return SerVer(*recs[0])
        return [SerVer(*x) for x in recs]

    def ser_ver_get_ids_for_series(self, series_idnum, version=None):
        """Get a list of ser_ver record IDs for a given series ID

        Args:
            series_idnum (int): ID num of the series to search
            version (int): Version number to search for, or None (or 0) for
                all

        Return:
            list of int: List of svids for the matching records; empty if
                there are none
        """
        if version:
            res = self.execute(
                'SELECT id FROM ser_ver WHERE series_id = ? AND version = ?',
                (series_idnum, version))
        else:
            res = self.execute(
                'SELECT id FROM ser_ver WHERE series_id = ?', (series_idnum,))

        # Each row is a 1-tuple; collect the ID from every row
        return [rec[0] for rec in res.fetchall()]

    def ser_ver_get_list(self):
        """Get a list of patchwork entries from the database

        Return:
            list of SER_VER
        """
        res = self.execute(
            'SELECT id, series_id, version, link, cover_id, '
            'cover_num_comments, name, archive_tag FROM ser_ver')
        items = res.fetchall()
        return [SerVer(*x) for x in items]

    def ser_ver_remove(self, series_idnum, version=None, remove_pcommits=True,
                       remove_series=True):
        """Delete a ser_ver record

        Removes the record which has the given series ID num and version

        Args:
            series_idnum (int): ID num of the series
            version (int): Version number, or None to remove all versions
            remove_pcommits (bool): True to remove associated pcommits too
            remove_series (bool): True to remove the series if version is None
        """
        if remove_pcommits:
            # Figure out svids to delete
            svids = self.ser_ver_get_ids_for_series(series_idnum, version)

            self.pcommit_delete_list(svids)

        if version:
            self.execute(
                'DELETE FROM ser_ver WHERE series_id = ? AND version = ?',
                (series_idnum, version))
        else:
            self.execute(
                'DELETE FROM ser_ver WHERE series_id = ?',
                (series_idnum,))
        if not version and remove_series:
            self.series_remove(series_idnum)

    # pcommit functions

    def pcommit_get_list(self, find_svid=None):
        """Get a list of pcommit entries from the database

        Args:
            find_svid (int): If not None, finds the records associated with a
                particular series and version; otherwise returns all records

        Return:
            list of PCOMMIT: pcommit records
        """
        query = ('SELECT id, seq, subject, svid, change_id, state, patch_id, '
                 'num_comments FROM pcommit')
        params = ()
        if find_svid is not None:
            query += ' WHERE svid = ?'
            params = (find_svid,)
        res = self.execute(query, params)
        return [Pcommit(*rec) for rec in res.fetchall()]

    def pcommit_add_list(self, svid, pcommits):
        """Add records to the pcommit table

        Args:
            svid (int): ser_ver ID num
            pcommits (list of PCOMMIT): Only seq, subject, change_id are
                used; svid comes from the argument passed in and the others
                are assumed to be obtained from patchwork later
        """
        for pcm in pcommits:
            self.execute(
                'INSERT INTO pcommit (svid, seq, subject, change_id) VALUES '
                '(?, ?, ?, ?)', (svid, pcm.seq, pcm.subject, pcm.change_id))

    def pcommit_delete(self, svid):
        """Delete pcommit records for a given ser_ver ID

        Args:
            svid (int): ser_ver ID num of records to delete
        """
        self.execute('DELETE FROM pcommit WHERE svid = ?', (svid,))

    def pcommit_delete_list(self, svid_list):
        """Delete pcommit records for a given set of ser_ver IDs

        Args:
            svid_list (list of int): ser_ver ID nums of records to delete;
                nothing is done if this is empty
        """
        svids = list(svid_list)
        if not svids:
            return

        # A single '?' can only bind one value, so generate one placeholder
        # for each ID in the list
        marks = ', '.join('?' * len(svids))
        self.execute(f'DELETE FROM pcommit WHERE svid IN ({marks})', svids)

    def pcommit_update(self, pcm):
        """Update a pcommit record

        Args:
            pcm (PCOMMIT): Information to write; only the idnum, state,
                patch_id and num_comments are used

        Return:
            True if the data was written
        """
        self.execute(
            'UPDATE pcommit SET '
            'patch_id = ?, state = ?, num_comments = ? WHERE id = ?',
            (pcm.patch_id, pcm.state, pcm.num_comments, pcm.idnum))
        return self.rowcount() > 0

    # upstream functions

    def upstream_add(self, name, url):
        """Add a new upstream record

        Args:
            name (str): Name of the tree
            url (str): URL for the tree

        Raises:
            ValueError if the name already exists in the database
            sqlite3.IntegrityError: for any other integrity failure
        """
        try:
            self.execute(
                'INSERT INTO upstream (name, url) VALUES (?, ?)', (name, url))
        except sqlite3.IntegrityError as exc:
            if 'UNIQUE constraint failed: upstream.name' in str(exc):
                raise ValueError(f"Upstream '{name}' already exists") from exc
            # Don't hide a failure which we don't understand
            raise

    def upstream_set_default(self, name):
        """Mark (only) the given upstream as the default

        Args:
            name (str): Name of the upstream remote to set as default, or None
                to leave no default set

        Raises:
            ValueError if the name does not match exactly one upstream;
                database is rolled back
        """
        self.execute("UPDATE upstream SET is_default = 0")
        if name is not None:
            self.execute(
                'UPDATE upstream SET is_default = 1 WHERE name = ?', (name,))
            if self.rowcount() != 1:
                self.rollback()
                raise ValueError(f"No such upstream '{name}'")

    def upstream_get_default(self):
        """Get the name of the default upstream

        Return:
            str: Default-upstream name, or None if there is no default
        """
        res = self.execute(
            "SELECT name FROM upstream WHERE is_default = 1")
        recs = res.fetchall()
        if len(recs) != 1:
            return None
        return recs[0][0]

    def upstream_delete(self, name):
        """Delete an upstream target

        Args:
            name (str): Name of the upstream remote to delete

        Raises:
            ValueError: Upstream does not exist (database is rolled back)
        """
        self.execute('DELETE FROM upstream WHERE name = ?', (name,))
        if self.rowcount() != 1:
            self.rollback()
            raise ValueError(f"No such upstream '{name}'")

    def upstream_get_dict(self):
        """Get a dict of upstream entries from the database

        Return:
            OrderedDict:
                key (str): upstream name
                value (tuple):
                    url: URL for the upstream
                    is_default: is_default flag stored for the upstream
        """
        res = self.execute('SELECT name, url, is_default FROM upstream')
        udict = OrderedDict()
        for name, url, is_default in res.fetchall():
            udict[name] = url, is_default
        return udict

    # settings functions

    def settings_update(self, name, proj_id, link_name):
        """Set the patchwork settings of the project

        Any existing settings are replaced.

        Args:
            name (str): Name of the project to use in patchwork
            proj_id (int): Project ID for the project
            link_name (str): Link name for the project
        """
        self.execute('DELETE FROM settings')
        self.execute(
            'INSERT INTO settings (name, proj_id, link_name) '
            'VALUES (?, ?, ?)', (name, proj_id, link_name))

    def settings_get(self):
        """Get the patchwork settings of the project

        Returns:
            tuple or None if there are no settings:
                name (str): Project name, e.g. 'U-Boot'
                proj_id (int): Patchworks project ID for this project
                link_name (str): Patchwork's link-name for the project
        """
        res = self.execute("SELECT name, proj_id, link_name FROM settings")
        recs = res.fetchall()
        if len(recs) != 1:
            return None
        return recs[0]