1#!/usr/bin/env python3
2
3#
4# Arm SCP/MCP Software
5# Copyright (c) 2021, Arm Limited and Contributors. All rights reserved.
6#
7# SPDX-License-Identifier: BSD-3-Clause
8#
9
10import argparse
11import colorama
12import difflib
13import os
14import sys
15
16import ruamel.yaml as yaml
17
18from io import StringIO
19
20
def colorize(diff):
    """Yield the lines of a unified diff wrapped in ANSI colour codes.

    Lines beginning with "@" (hunk headers) are blue, "+" lines green and
    "-" lines red; the first matching prefix wins, checked in that order.
    Any other line is yielded untouched.
    """
    palette = (
        ("@", colorama.Fore.BLUE),
        ("+", colorama.Fore.GREEN),
        ("-", colorama.Fore.RED),
    )

    for text in diff:
        fore = next(
            (colour for prefix, colour in palette if text.startswith(prefix)),
            None,
        )

        if fore is None:
            yield text
        else:
            yield fore + text + colorama.Fore.RESET
31
32
def main():
    """Check or reformat YAML sources by round-tripping them through ruamel.yaml.

    Sub-commands:

    * ``diff``: write a unified diff between each source and its
      round-tripped text. With ``--check``, the exit status is
      ``EX_DATAERR`` when any source differs.
    * ``format``: write the round-tripped text to stdout, to ``--output``,
      or back into each source (``--in-place``).

    Returns:
        The process exit status. This is ``EX_OK`` (0) normally, or
        ``EX_DATAERR`` (65) for ``diff --check`` when changes are required.
    """
    # fmt: off

    parser = argparse.ArgumentParser()

    parser_common = argparse.ArgumentParser(add_help=False)
    parser_common.add_argument("sources",
                               help="list of source files to format or check",
                               nargs="*")
    parser_common.add_argument("-q", "--quiet", action="store_true",
                               help="suppress output intended for humans")

    subparsers = parser.add_subparsers(dest="command")

    # Without a sub-command there is no "sources" attribute and the loop
    # below would fail with AttributeError; make argparse report it instead.
    subparsers.required = True

    subparser_diff = subparsers.add_parser("diff",
                                           help=("generate a unified diff " +
                                                 "of required changes"),
                                           parents=[parser_common])

    subparser_diff.add_argument("--check",
                                action="store_true",
                                help=("exit with an error status code " +
                                      "if changes are required"))

    subparser_diff.add_argument("-o", "--output",
                                help="file to write the unified diff file to")

    subparser_format = subparsers.add_parser("format",
                                             help=("automatically format " +
                                                   "source files"),
                                             parents=[parser_common])

    subparser_format_output_group = \
        subparser_format.add_mutually_exclusive_group()
    subparser_format_output_group.add_argument("-i", "--in-place",
                                               action="store_true",
                                               help="format the file in-place")
    subparser_format_output_group.add_argument("-o", "--output",
                                               help=("file to write " +
                                                     "output data to"))

    # fmt: on

    args = parser.parse_args()

    # Must be called, not merely referenced, for colorama to take effect.
    colorama.init()

    # The os.EX_* constants are only defined on Unix; fall back to their values.
    ex_ok = getattr(os, "EX_OK", 0)
    ex_dataerr = getattr(os, "EX_DATAERR", 65)

    result = ex_ok

    in_place = getattr(args, "in_place", False)
    output = getattr(args, "output", None)

    # Text for --output is buffered and written once after every source has
    # been read. This keeps results from all sources, where previously each
    # source re-truncated the file. It also lets --output name one of the
    # sources without destroying it before it is read.
    sink = StringIO() if output else sys.stdout

    # Only colourise a diff going to an interactive terminal, never a file.
    use_color = args.command == "diff" and not output and sys.stdout.isatty()

    for source in args.sources:
        source = os.path.relpath(source)

        with open(source, "r", encoding="utf-8") as istream:
            idata = istream.read()

        # The round-tripped text is the reference formatting.
        # NOTE(review): round_trip_load/round_trip_dump are ruamel.yaml's
        # legacy functional API; newer releases steer users to
        # ruamel.yaml.YAML(), whose defaults may differ. Verify the output
        # before migrating.
        data = yaml.round_trip_load(idata)

        with StringIO() as buffer:
            yaml.round_trip_dump(data, buffer)
            odata = buffer.getvalue()

        if args.command == "diff":
            if idata == odata:
                continue

            if not args.quiet:
                print("Style violations found: " + source, file=sys.stderr)

            if args.check:
                result = ex_dataerr

            diff = difflib.unified_diff(
                idata.splitlines(keepends=True),
                odata.splitlines(keepends=True),
                fromfile=source + " (original)",
                tofile=source + " (reformatted)",
            )

            sink.writelines(colorize(diff) if use_color else diff)
        elif args.command == "format":
            if in_place:
                # The source has already been read in full, so truncating it
                # here is safe.
                with open(source, "w", encoding="utf-8") as ostream:
                    ostream.write(odata)
            else:
                sink.write(odata)

    if output:
        with open(output, "w", encoding="utf-8") as ostream:
            ostream.write(sink.getvalue())

    return result
126
127
if __name__ == "__main__":
    # Hand main()'s status straight to the interpreter as the exit code.
    exit_status = main()
    sys.exit(exit_status)
130