1#!/usr/bin/env python3
2
3# A utility for installing LK toolchains.
4
5from __future__ import annotations
6
7import argparse
8import html.parser
9import io
10import os
11import pathlib
12import sys
13import tarfile
14import threading
15import urllib.request
16from typing import Self
17
# Where the prebuilt toolchain tarballs are published.
BASE_URL = "https://newos.org/toolchains"

# Identity of the machine running this script, as reported by os.uname();
# these are the defaults for --host-os / --host-cpu.
HOST_OS, HOST_CPU = os.uname().sysname, os.uname().machine

# Root of the LK tree: two directory levels above this script, with symlinks
# resolved. Toolchains are installed under it by default.
LK_ROOT = pathlib.Path(os.path.realpath(__file__)).parent.parent
DEFAULT_TOOLCHAIN_DIR = LK_ROOT / "toolchain"

# File suffix of every published toolchain archive.
TAR_EXT = ".tar.xz"
27
28
def main() -> int:
    """Install (or, with --list, just list) the matching LK toolchains.

    Returns:
        The process exit status: 0 on success (including when nothing matches
        or every match is already installed), 1 if the toolchain index could
        not be fetched or any download failed.
    """
    args = _parse_args()

    # Get the full list of remote toolchains available for the provided host.
    try:
        toolchains = _fetch_toolchains(args.host_os, args.host_cpu)
    except OSError as err:
        # urllib.error.URLError (and its HTTPError subclass) derive from OSError.
        print(f"Error accessing {BASE_URL}: {err}")
        return 1

    # Filter them given --prefix and --version selections.
    toolchains = _select_toolchains(toolchains, args.prefix, args.version)

    if not toolchains:
        print("No matching toolchains")
        return 0

    if args.list:
        print("Matching toolchains:")
        for toolchain in toolchains:
            print(toolchain.name)
        return 0

    # Every download runs on its own thread. A toolchain's name is recorded in
    # `succeeded` only after it has been fully extracted, so any failure --
    # including an unexpected exception that kills the thread -- is reflected
    # in the exit status instead of being lost.
    succeeded: list[str] = []

    def download(toolchain: RemoteToolchain) -> None:
        if _download_toolchain(toolchain, args.install_dir):
            succeeded.append(toolchain.name)  # list.append is thread-safe in CPython

    started: list[str] = []
    threads: list[threading.Thread] = []
    for toolchain in toolchains:
        local = args.install_dir.joinpath(toolchain.name)
        if local.exists() and not args.force:
            print(
                f"{toolchain.name} already installed; "
                "skipping... (pass --force to overwrite)",
            )
            continue
        print(f"Downloading {toolchain.name} to {local}...")
        thread = threading.Thread(target=download, args=(toolchain,))
        thread.start()
        threads.append(thread)
        started.append(toolchain.name)

    for thread in threads:
        thread.join()

    failed = [name for name in started if name not in succeeded]
    if failed:
        print(f"Failed to install: {', '.join(failed)}")
        return 1
    return 0


def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse sys.argv."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Installs the matching LK toolchains from the official host, "
        + BASE_URL,
    )
    parser.add_argument(
        "--list",
        help="just list the matching toolchains; don't download them",
        action="store_true",
    )
    parser.add_argument(
        "--prefix",
        help="a toolchain prefix on which to match. If none are specified, all prefixes"
        " will match",
        nargs="*",
    )
    parser.add_argument(
        "--version",
        help='the exact toolchain version to match, or "latest" to specify only the '
        'latest version, or "all" for all versions',
        type=str,
        default="latest",
    )
    parser.add_argument(
        "--install-dir",
        help="the directory at which to install the toolchains",
        type=pathlib.Path,
        default=DEFAULT_TOOLCHAIN_DIR,
    )
    parser.add_argument(
        "--force",
        help="whether to overwrite past installed versions of matching toolchains",
        action="store_true",
    )
    parser.add_argument(
        "--host-os",
        help="the toolchains' host OS",
        type=str,
        default=HOST_OS,
    )
    parser.add_argument(
        "--host-cpu",
        help="the toolchains' host architecture",
        type=str,
        default=HOST_CPU,
    )
    return parser.parse_args()


def _fetch_toolchains(host_os: str, host_cpu: str) -> list[RemoteToolchain]:
    """Return the toolchains listed at BASE_URL that were built for the given host.

    Raises:
        OSError: if BASE_URL cannot be fetched (urllib's URLError/HTTPError are
            OSError subclasses) or answers with a status other than 200.
    """
    with urllib.request.urlopen(BASE_URL) as response:
        if response.status != 200:
            raise OSError(f"HTTP status {response.status}")
        html_parser = RemoteToolchainHTMLParser(host_os, host_cpu)
        html_parser.feed(response.read().decode("utf-8"))
    return html_parser.toolchains


def _select_toolchains(
    toolchains: list[RemoteToolchain], prefixes: list[str] | None, version: str
) -> list[RemoteToolchain]:
    """Apply the --prefix and --version selections.

    Args:
        toolchains: candidate toolchains, in any order (not modified).
        prefixes: prefixes to keep; None or empty keeps every prefix.
        version: an exact version string, "latest" for the newest version of
            each prefix, or "all" for every version.

    Returns:
        The selected toolchains, ordered by (prefix, version).
    """
    selected = sorted(toolchains)
    if prefixes:
        wanted = set(prefixes)
        selected = [t for t in selected if t.prefix in wanted]
    if version == "latest":
        # `selected` is sorted ascending on (prefix, version tokens), so the
        # last toolchain seen for a prefix is its newest. Dicts keep the
        # position of a key's first insertion, so prefix order is preserved.
        newest: dict[str, RemoteToolchain] = {}
        for toolchain in selected:
            newest[toolchain.prefix] = toolchain
        selected = list(newest.values())
    elif version != "all":
        selected = [t for t in selected if t.version == version]
    return selected


def _download_toolchain(toolchain: RemoteToolchain, install_dir: pathlib.Path) -> bool:
    """Download `toolchain`'s tarball and extract it under `install_dir`.

    Returns:
        True on success. Non-200 responses, network/filesystem errors
        (OSError) and archive errors (tarfile.TarError, EOFError) are reported
        on stdout and yield False; any other exception propagates.
    """
    try:
        with urllib.request.urlopen(toolchain.url) as response:
            if response.status != 200:
                print(f"Error while downloading {toolchain.name}: {response.status}")
                return False
            payload = response.read()
        with tarfile.open(fileobj=io.BytesIO(payload), mode="r:xz") as archive:
            # The "data" filter rejects absolute paths, links escaping the
            # destination, and other unsafe members.
            archive.extractall(path=install_dir, filter="data")
    except (OSError, EOFError, tarfile.TarError) as err:
        print(f"Error while downloading {toolchain.name}: {err}")
        return False
    return True
143
144
145class RemoteToolchain:
146    def __init__(self, prefix: str, version: str, host_os: str, host_cpu: str) -> None:
147        self._prefix = prefix
148        self._version = [int(token) for token in version.split(".")]
149        self._host = f"{host_os}-{host_cpu}"
150
151    # Orders toolchains lexicographically on (prefix, version tokens).
152    def __lt__(self, other: Self) -> bool:
153        return self._prefix < other.prefix or (
154            self._prefix == other.prefix and self._version < other._version
155        )
156
157    @property
158    def prefix(self) -> str:
159        return self._prefix
160
161    @property
162    def version(self) -> str:
163        return ".".join(map(str, self._version))
164
165    @property
166    def name(self) -> str:
167        return f"{self._prefix}-{self.version}-{self._host}"
168
169    @property
170    def url(self) -> str:
171        return f"{BASE_URL}/{self.name}{TAR_EXT}"
172
173
class RemoteToolchainHTMLParser(html.parser.HTMLParser):
    """A simple HTML parser that extracts the toolchains listed at BASE_URL.

    It expects each toolchain to appear as the text of a hyperlink, named
    ``<p1>-<p2>-<version>-<host_os>-<host_cpu>`` followed by TAR_EXT, where the
    toolchain prefix is ``<p1>-<p2>``. Only toolchains built for the host given
    to the constructor are kept. Once the HTML has been passed to feed(), the
    parsed toolchains are available via the ``toolchains`` property.
    """

    # HTML void elements never get an end tag. They are not pushed onto the
    # open-tag stack; otherwise they would linger there (e.g. an <img> inside
    # a hyperlink would hide the enclosing <a> from handle_data()).
    _VOID_ELEMENTS = frozenset(
        {
            "area", "base", "br", "col", "embed", "hr", "img", "input",
            "link", "meta", "param", "source", "track", "wbr",
        }
    )

    def __init__(self, host_os: str, host_cpu: str) -> None:
        html.parser.HTMLParser.__init__(self)
        self._toolchains: list[RemoteToolchain] = []
        # Names of the currently open elements, innermost last.
        self._tags: list[str] = []
        self._host_os = host_os
        self._host_cpu = host_cpu

    @property
    def toolchains(self) -> list[RemoteToolchain]:
        """The toolchains parsed so far, in document order."""
        return self._toolchains

    #
    # The following methods implement the parsing, overriding those defined in
    # the base class.
    #

    def handle_starttag(self, tag: str, _: list[tuple[str, str | None]]) -> None:
        if tag not in self._VOID_ELEMENTS:
            self._tags.append(tag)

    def handle_endtag(self, tag: str) -> None:
        # Close the innermost open element with this name, together with
        # anything left unclosed inside it (e.g. implicitly closed <li>/<p>).
        # A stray end tag with no matching start tag is ignored, instead of
        # popping an unrelated element or raising IndexError on an empty stack.
        for index in reversed(range(len(self._tags))):
            if self._tags[index] == tag:
                del self._tags[index:]
                return

    def handle_data(self, data: str) -> None:
        # Only process hyperlinks with tarball names.
        if not self._tags or self._tags[-1] != "a":
            return
        name = data.strip()
        if not name.endswith(TAR_EXT):
            return
        tokens = name.removesuffix(TAR_EXT).split("-")
        if len(tokens) != 5:
            print(f"Warning: malformed toolchain name: {name}")
            return
        prefix = tokens[0] + "-" + tokens[1]
        version, host_os, host_cpu = tokens[2:]
        if host_os != self._host_os or host_cpu != self._host_cpu:
            return
        try:
            toolchain = RemoteToolchain(prefix, version, host_os, host_cpu)
        except ValueError:
            # A non-integer version token; skip this entry rather than abort
            # parsing the whole page.
            print(f"Warning: malformed toolchain version: {name}")
            return
        self._toolchains.append(toolchain)
218
219
if __name__ == "__main__":
    # Hand main()'s return value to the shell as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
222