1#!/usr/bin/env python3
2#
3# This file is part of the MicroPython project, http://micropython.org/
4#
5# The MIT License (MIT)
6#
7# Copyright (c) 2020 Damien P. George
8# Copyright (c) 2020 Jim Mussared
9#
10# Permission is hereby granted, free of charge, to any person obtaining a copy
11# of this software and associated documentation files (the "Software"), to deal
12# in the Software without restriction, including without limitation the rights
13# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14# copies of the Software, and to permit persons to whom the Software is
15# furnished to do so, subject to the following conditions:
16#
17# The above copyright notice and this permission notice shall be included in
18# all copies or substantial portions of the Software.
19#
20# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26# THE SOFTWARE.
27
import argparse
import fnmatch
import glob
import itertools
import os
import re
import subprocess
34
# Glob patterns selecting the files to format, relative to the top-level repo dir.
# They are expanded with recursive globbing, so "**" matches any number of
# subdirectories (including none).
PATHS = [
    # C
    "extmod/*.[ch]",
    "extmod/btstack/*.[ch]",
    "extmod/nimble/*.[ch]",
    "lib/mbedtls_errors/tester.c",
    "shared/netutils/*.[ch]",
    "shared/timeutils/*.[ch]",
    "shared/runtime/*.[ch]",
    "mpy-cross/*.[ch]",
    "ports/*/*.[ch]",
    "ports/windows/msvc/**/*.[ch]",
    "ports/nrf/modules/nrf/*.[ch]",
    "py/*.[ch]",
    # Python
    "drivers/**/*.py",
    "examples/**/*.py",
    "extmod/**/*.py",
    "ports/**/*.py",
    "py/**/*.py",
    "tools/**/*.py",
    "tests/**/*.py",
]

# Patterns (also relative to the top-level repo dir) whose matches are removed from
# the PATHS results. They are applied with fnmatch, where "*" also matches "/", so
# e.g. "ports/*/build*" drops every file anywhere under a port's build directory.
EXCLUSIONS = [
    # STM32 build includes generated Python code.
    "ports/*/build*",
    # gitignore in ports/unix ignores *.py, so also do it here.
    "ports/unix/*.py",
    # Not real Python files.
    "tests/**/repl_*.py",
    # Needs careful attention before applying automatic formatting.
    "tests/basics/*.py",
]
70
# Absolute path of the repository's top-level directory: the parent of the
# directory that contains this script.
TOP = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))

# uncrustify configuration file used when formatting C sources.
UNCRUSTIFY_CFG = os.path.join(TOP, "tools/uncrustify.cfg")

# File extensions (compared after lower-casing) that identify C and Python sources.
C_EXTS = (".c", ".h")
PY_EXTS = (".py",)
81
82
def list_files(paths, exclusions=None, prefix=""):
    """Expand glob patterns into a sorted list of unique file paths.

    Args:
        paths: glob patterns, each joined onto *prefix*. They are expanded with
            ``recursive=True``, so ``**`` matches any number of directories.
        exclusions: optional fnmatch patterns, also joined onto *prefix*. Any
            path they match is dropped. Note that fnmatch's ``*`` also matches
            the path separator.
        prefix: directory the patterns are relative to ("" means the current
            directory). It is taken literally, even if it contains glob
            metacharacters.

    Returns:
        The sorted list of matching paths, each beginning with *prefix*.
    """
    # Escape the prefix so that "[", "*" or "?" in the directory name (e.g. a
    # checkout under "work[2]") are matched literally instead of being treated as
    # pattern syntax by glob/fnmatch. glob.escape("") == "", so "" is unaffected.
    base = glob.escape(prefix)
    files = set()
    for pattern in paths:
        files.update(glob.glob(os.path.join(base, pattern), recursive=True))
    for pattern in exclusions or ():
        # fnmatch.filter returns a new list, so updating `files` here is safe.
        files.difference_update(fnmatch.filter(files, os.path.join(base, pattern)))
    return sorted(files)
90
91
def fixup_c(filename):
    """Re-indent indented preprocessor conditionals in a C file, in place.

    Called by main() after uncrustify has run. An indented ``#if``/``#ifdef``/
    ``#ifndef`` line is dedented by 4 columns when the line right after it starts
    with ``} else `` or ``case `` and is indented exactly 4 columns less. Every
    associated ``#elif``/``#else``/``#endif`` is then shifted to the same column.
    The effect is that such directives align with the following line rather than
    the preceding one. Directives at column 0 are not examined.

    The whole file is processed in memory before it is rewritten, so a failed
    consistency check never leaves a partially written file behind. The file is
    read and written as UTF-8. Line endings are normalised to LF, because input
    uses universal newlines and output does no newline translation.

    Raises:
        AssertionError: if an ``#elif``/``#else``/``#endif`` has no open indented
            ``#if``, or sits left of the column its ``#if`` was moved to. In
            both cases the file is left untouched. Also raised if indented
            conditionals are still open at end of file; that check runs after
            the file has been rewritten.
    """
    with open(filename, encoding="utf-8") as f:
        lines = f.readlines()

    out = []
    # One entry per open indented #if: the column it was moved to, or -1 if it was
    # left where it was.
    dedent_stack = []
    for i, l in enumerate(lines):
        m = re.match(r"( +)#(if |ifdef |ifndef |elif |else|endif)", l)
        if m:
            indent = len(m.group(1))
            directive = m.group(2)
            if directive in ("if ", "ifdef ", "ifndef "):
                # An #if on the last line has no following line; treat it as blank.
                l_next = lines[i + 1] if i + 1 < len(lines) else ""
                indent_next = len(re.match(r"( *)", l_next).group(1))
                if indent - 4 == indent_next and re.match(r" +(} else |case )", l_next):
                    # This #-line (and all associated ones) needs dedenting by 4 spaces.
                    l = l[4:]
                    dedent_stack.append(indent - 4)
                else:
                    # This #-line does not need dedenting.
                    dedent_stack.append(-1)
            else:
                assert dedent_stack, "%s: #%s with no matching indented #if" % (
                    filename,
                    directive.strip(),
                )
                if dedent_stack[-1] >= 0:
                    # This associated #-line needs dedenting to match the #if.
                    indent_diff = indent - dedent_stack[-1]
                    assert indent_diff >= 0, filename
                    l = l[indent_diff:]
                if directive == "endif":
                    dedent_stack.pop()
        out.append(l)

    # Only rewrite the file once every line has been processed successfully.
    with open(filename, "w", encoding="utf-8", newline="") as f:
        f.writelines(out)

    # Any entry left here is an indented #if that was never closed by an
    # indented #endif.
    assert not dedent_stack, filename
132
133
def main():
    """Command-line entry point: format C and/or Python files in place.

    ``-c`` restricts formatting to C and ``-p`` to Python; giving both or neither
    formats both. Positional arguments are globs, expanded relative to the
    current directory, that replace the built-in PATHS/EXCLUSIONS selection.
    C files go through uncrustify and then fixup_c(); Python files go through
    black. A formatter exiting with an error raises
    subprocess.CalledProcessError.
    """
    parser = argparse.ArgumentParser(description="Auto-format C and Python files.")
    parser.add_argument("-c", action="store_true", help="Format C code only")
    parser.add_argument("-p", action="store_true", help="Format Python code only")
    parser.add_argument("-v", action="store_true", help="Enable verbose output")
    parser.add_argument("files", nargs="*", help="Run on specific globs")
    opts = parser.parse_args()

    # A lone -c or -p disables the other language; otherwise run both.
    want_c = opts.c or not opts.p
    want_py = opts.p or not opts.c

    if opts.files:
        # Explicit globs: expanded as given, with no exclusions applied.
        selected = list_files(opts.files)
    else:
        selected = list_files(PATHS, EXCLUSIONS, TOP)

    def pick(exts):
        """Return the selected paths whose lower-cased extension is in exts."""
        return [path for path in selected if os.path.splitext(path)[1].lower() in exts]

    def run_chunked(cmd, paths, chunk_size=200):
        """Invoke cmd on at most chunk_size paths at a time to bound the command line."""
        for start in range(0, len(paths), chunk_size):
            subprocess.check_call(cmd + paths[start : start + chunk_size])

    if want_c:
        c_paths = pick(C_EXTS)
        uncrustify_cmd = ["uncrustify", "-c", UNCRUSTIFY_CFG, "-lC", "--no-backup"]
        if not opts.v:
            uncrustify_cmd.append("-q")
        run_chunked(uncrustify_cmd, c_paths)
        # Apply the #-directive re-indentation on top of uncrustify's output.
        for path in c_paths:
            fixup_c(path)

    if want_py:
        black_cmd = ["black", "--fast", "--line-length=99", "-v" if opts.v else "-q"]
        run_chunked(black_cmd, pick(PY_EXTS))
184
185
if __name__ == "__main__":
    # Run the formatter only when executed as a script, not when imported.
    main()
188