1#!/usr/bin/env python3
2#
3# Arm SCP/MCP Software
4# Copyright (c) 2015-2024, Arm Limited and Contributors. All rights reserved.
5#
6# SPDX-License-Identifier: BSD-3-Clause
7#
8
9"""
    Check the source code for trailing spaces, trailing blank lines, missing
    newlines at the end of files, non-UNIX line endings and incorrect spacing
    after keywords in C files.
11"""
12
13import argparse
14import os
15import re
16import sys
17import fnmatch
18import glob
19from dataclasses import dataclass, asdict
20from utils import banner, get_filtered_files
21
22
23#
24# Directories to exclude
25#
26
27# Exclude all mod_test 'mocks' directories
28UNIT_TEST_MOCKS = glob.glob('module/**/test/**/mocks', recursive=True) +\
29                  glob.glob('product/**/test/**/mocks', recursive=True)
30
31EXCLUDE_DIRECTORIES = [
32    '.git',
33    'build',
34    'tools',
35    'contrib',
36    'product/rcar/src/CMSIS-FreeRTOS',
37    'unit_test/unity_mocks',
38] + UNIT_TEST_MOCKS
39
40#
41# File types to check
42#
43FILE_TYPES = [
44    '*.c',
45    '*.h',
46    '*.cmake',
47    'CMakeLists.txt',
48    '*.yml',
49    '*.yaml',
50    'Dockerfile',
51    '*.sh',
52    '*.mk',
53]
54
55#
56# Code file types to check
57#
58CODE_FILES = [
59    '*.c',
60    '*.h',
61]
62
63#
64# Code keywords to check
65#
66KEYWORDS = [
67    'for',
68    'if',
69    'switch',
70    'while',
71]
72
73
@dataclass
class Analysis:
    """Counters collected by the spacing checks, per file or aggregated.

    Every field is an integer counter; any non-zero counter makes
    has_errors() report a failure.
    """

    # Lines reported as having trailing spaces.
    trailing_spaces: int = 0
    # Amount reported for blank lines at the end of files.
    trailing_lines: int = 0
    # Keyword uses with abnormal spacing before '('.
    incorrect_spaces: int = 0
    # Files rewritten in trim/correct mode.
    modified_files: int = 0
    # Files with non-UNIX or mixed line endings.
    non_unix_eol_files: int = 0
    # Files whose last line has no terminating newline.
    missing_new_line_files: int = 0

    def has_errors(self) -> bool:
        """Return True when at least one counter is non-zero."""
        return any(asdict(self).values())

    def add(self, partial_analysis):
        """Add every counter of partial_analysis (a dataclass) into self.

        Returns self so calls can be chained.
        """
        counters = asdict(partial_analysis)
        for field_name in counters:
            if not hasattr(self, field_name):
                # Unknown counter: warn and skip it.
                print(f'Warning: Attribute {field_name} not found.')
                continue
            current = getattr(self, field_name)
            setattr(self, field_name, current + counters[field_name])
        return self

    def __str__(self) -> str:
        """Human-readable, one-line-per-counter summary."""
        rows = (
            ('trailing spaces', self.trailing_spaces),
            ('trailing lines', self.trailing_lines),
            ('abnormal spaces', self.incorrect_spaces),
            ('files with non-UNIX or mixed line endings',
             self.non_unix_eol_files),
            ('files with missing newlines at EOF',
             self.missing_new_line_files),
        )
        summary = ''.join(self._str_message(label, total)
                          for label, total in rows)
        if self.modified_files:
            summary += f'- {self.modified_files} files modified.\n'
        return summary

    def _str_message(self, name, count) -> str:
        """Format one summary line for the counter called name."""
        return (f'- No {name} found.\n' if count == 0
                else f'- {count} {name} found.\n')
113
114
def is_valid_code_type(filename):
    """Return True if filename matches any fnmatch pattern in CODE_FILES."""
    for pattern in CODE_FILES:
        if fnmatch.fnmatch(filename, pattern):
            return True
    return False
118
119
def get_regex_patterns(keywords):
    """Build one compiled regex per keyword for the keyword spacing check.

    Each pattern is anchored at the start of the line and captures the four
    groups that check_line() uses:
      1. the text before the keyword: empty when the keyword starts the
         line, otherwise ending in a non-word character, so the keyword is
         never matched as the tail of a longer identifier;
      2. the keyword itself;
      3. the whitespace between the keyword and '(';
      4. the '(' and the rest of the line, excluding the newline.
    Group 1 is greedy, so the last occurrence on the line is preferred.

    Args:
        keywords: iterable of keywords (any iterable, including one-shot
            iterators). Each keyword is escaped and matched literally.

    Returns:
        dict mapping each keyword to its compiled pattern, in input order.
    """
    return {
        keyword: re.compile(
            rf'^((?:.*\W)?)({re.escape(keyword)})(\s*)(\(.*)')
        for keyword in keywords
    }
127
128
def check_line(path, line, regex_patterns, analysis, trim, correct):
    """Check one line for spacing issues and optionally fix them.

    Trailing whitespace is reported for every file type. For code files
    (as decided by is_valid_code_type()), every keyword in regex_patterns
    must be followed by exactly one space before its '('.

    Args:
        path: path of the file the line comes from. It is used in messages
            and to decide whether the keyword checks apply.
        line: line text as read in text mode. It ends with '\n', except
            possibly the last line of a file.
        regex_patterns: mapping of keyword -> compiled pattern, as built by
            get_regex_patterns().
        analysis: counters object; trailing_spaces and incorrect_spaces are
            incremented in place.
        trim: if true, strip trailing whitespace from the returned line.
        correct: if true, normalise keyword spacing in the returned line.

    Returns:
        The (possibly modified) line. Its original line ending (a newline,
        or none on an unterminated last line) is preserved.
    """
    # Text-mode reads translate every line break to '\n', so this works
    # regardless of the line-ending convention used on disk.
    ending = '\n' if line.endswith('\n') else ''
    body = line[:-1] if ending else line
    stripped = body.rstrip()
    if stripped != body:
        print(f'{path} has trailing space --> {stripped}')
        analysis.trailing_spaces += 1
        if trim:
            line = stripped + ending

    if not is_valid_code_type(os.path.basename(path)):
        return line

    for keyword, pattern in regex_patterns.items():
        # Cheap substring pre-filter before running the regex.
        if keyword not in line:
            continue
        match = pattern.search(line)
        if match is None or match.group(3) == ' ':
            continue
        analysis.incorrect_spaces += 1
        print(f"Abnormal spacing. '{keyword}', {path} --> {line.strip()}")
        if correct:
            # Group 4 stops before the newline, so restore the ending.
            line = (match.group(1) + match.group(2) + ' ' +
                    match.group(4) + ending)

    return line
155
156
def write_file(path, analysis, modified_file, trim, correct):
    """Overwrite path with modified_file when a requested fix applies.

    The file is rewritten in either of two cases:
      - trim is set and trailing spaces or trailing lines were counted;
      - correct is set and abnormal keyword spacing was counted.
    On rewrite, analysis.modified_files is incremented.

    Args:
        path: file to overwrite.
        analysis: counters for this file only.
        modified_file: complete new contents of the file.
        trim: whether trailing-whitespace trimming was requested.
        correct: whether keyword-spacing correction was requested.
    """
    needs_write = False
    if trim and (analysis.trailing_spaces or analysis.trailing_lines):
        print('Trimming {}...'.format(path))
        needs_write = True
    if correct and analysis.incorrect_spaces:
        print('Correcting {}...'.format(path))
        needs_write = True
    if not needs_write:
        return

    analysis.modified_files += 1
    # Files are decoded as UTF-8 when read, so encode them back the same
    # way. newline='\n' writes UNIX line endings on every platform instead
    # of translating to os.linesep.
    with open(path, 'w', encoding='utf-8', newline='\n') as file:
        file.write(modified_file)
173
174
def _count_trailing_blank_lines(content):
    """Return how many blank lines follow the last non-blank line."""
    tail = content[len(content.rstrip()):]
    return max(tail.count('\n') - 1, 0)


def _check_file(path, regex_patterns, trim, correct):
    """Run every check on one file, apply fixes, and return its Analysis."""
    partial_analysis = Analysis()
    try:
        with open(path, encoding='utf-8') as file:
            lines = file.readlines()
            # Read while the file is still open: None if no line break was
            # seen, a str for a single kind, or a tuple of every kind seen.
            newlines = file.newlines
    except UnicodeDecodeError:
        print(f'Invalid file format {path}')
        return partial_analysis

    if lines and not lines[-1].endswith('\n'):
        partial_analysis.missing_new_line_files += 1
        print(f'{path}: is missing a new line at the end of file')

    # join() avoids the quadratic cost of repeated string +=.
    content = ''.join(
        check_line(path, line, regex_patterns, partial_analysis, trim,
                   correct)
        for line in lines)

    if content.endswith('\n\n'):
        print(f'Blank line at the end of file --> {path}')
        # Count before trimming; counting afterwards would always give 0,
        # so in trim mode the file would be neither reported nor rewritten.
        partial_analysis.trailing_lines += _count_trailing_blank_lines(
            content)
        if trim:
            content = content.rstrip() + '\n'

    if isinstance(newlines, tuple):
        print('{} has mixed line endings'.format(path))
        partial_analysis.non_unix_eol_files += 1
    elif newlines and newlines != '\n':
        print('{} has non-UNIX line endings'.format(path))
        partial_analysis.non_unix_eol_files += 1

    write_file(path, partial_analysis, content, trim, correct)
    return partial_analysis


def check_files(file_paths, regex_patterns, trim, correct):
    """Check each file in file_paths and aggregate the results.

    Args:
        file_paths: iterable of paths to check.
        regex_patterns: keyword -> compiled pattern mapping, as built by
            get_regex_patterns().
        trim: forwarded to check_line()/write_file() to strip trailing
            whitespace and blank lines in place.
        correct: forwarded to check_line()/write_file() to fix keyword
            spacing in place.

    Returns:
        An Analysis whose counters are summed over all files.
    """
    analysis = Analysis()
    for path in file_paths:
        analysis.add(_check_file(path, regex_patterns, trim, correct))
    return analysis
224
225
def run(trim=False, correct=False):
    """Check all filtered source files and print a summary.

    Args:
        trim: passed through to check_files() to enable in-place trimming.
        correct: passed through to check_files() to enable in-place
            keyword-spacing correction.

    Returns:
        True when the aggregated analysis reports no errors, else False.
    """
    print(banner('Checking for incorrect spacing in the code...'))

    mode_notices = ((trim, 'Trim mode is enabled.'),
                    (correct, 'Correct mode is enabled.'))
    for enabled, notice in mode_notices:
        if enabled:
            print(notice)

    patterns = get_regex_patterns(KEYWORDS)
    candidates = get_filtered_files(EXCLUDE_DIRECTORIES, FILE_TYPES)
    summary = check_files(candidates, patterns, trim, correct)
    print(summary)

    return not summary.has_errors()
240
241
def parse_args(argv, prog_name):
    """Parse this checker's command-line options.

    Args:
        argv: argument list to parse (argparse falls back to sys.argv[1:]
            when this is None).
        prog_name: program name shown in usage/help output.

    Returns:
        argparse.Namespace with boolean 'trim' and 'correct' attributes.
    """
    # (short flag, long flag, destination, help text)
    options = (
        ('-t', '--trim', 'trim', 'Remove trailing spaces.'),
        ('-c', '--correct', 'correct', 'Correct spaces after keywords.'),
    )

    parser = argparse.ArgumentParser(
        prog=prog_name,
        description='Perform checks for incorrect spacing in the code')
    for short_flag, long_flag, destination, help_text in options:
        parser.add_argument(short_flag, long_flag, dest=destination,
                            action='store_true', default=False,
                            required=False, help=help_text)

    return parser.parse_args(argv)
256
257
def main(argv=(), prog_name=''):
    """Command-line entry point.

    Args:
        argv: command-line arguments, excluding the program name. The
            default is an empty immutable sequence (no options enabled),
            which avoids the shared-mutable-default pitfall of ``[]``.
        prog_name: program name used in argparse usage/help output.

    Returns:
        Process exit status: 0 when run() reports success, 1 otherwise.
    """
    args = parse_args(argv, prog_name)
    return 0 if run(args.trim, args.correct) else 1
261
262
if __name__ == '__main__':
    # Forward the CLI arguments and use main()'s result as the exit status.
    sys.exit(main(prog_name=sys.argv[0], argv=sys.argv[1:]))
265