1# Copyright 2015 The BoringSSL Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Extracts archives."""
16
17
18import hashlib
19import optparse
20import os
21import os.path
22import tarfile
23import shutil
24import sys
25import zipfile
26
27
def CheckedJoin(output, path):
  """
  CheckedJoin returns os.path.join(output, path). It does sanity checks to
  ensure the resulting path is under output, but shouldn't be used on untrusted
  input.

  path is normalized first. A ValueError carrying the normalized path is raised
  if it is absolute, if it is the current directory itself, or if it climbs out
  of output through a leading '..' component. Names that merely begin with a
  dot, such as '.gitignore' or '..data', are ordinary entries and are accepted.
  """
  path = os.path.normpath(path)
  if os.path.isabs(path):
    raise ValueError(path)
  # After normalization, any '..' components that survive in a relative path
  # are at the front, so inspecting the first component is enough to catch an
  # escape. '.' is what an empty or self-referential path normalizes to; it
  # names output itself rather than something inside it.
  first_component = path.split(os.sep, 1)[0]
  if path == os.curdir or first_component == os.pardir:
    raise ValueError(path)
  return os.path.join(output, path)
38
39
class FileEntry(object):
  """
  FileEntry describes a regular file found in an archive. path is the entry's
  name as stored in the archive, mode is the permission bits to apply to the
  extracted file (or None to leave them alone), and fileobj is a readable
  file-like object holding the entry's contents.
  """

  def __init__(self, path, mode, fileobj):
    self.path, self.mode, self.fileobj = path, mode, fileobj
45
46
class SymlinkEntry(object):
  """
  SymlinkEntry describes a symbolic link found in an archive. path is the
  link's own name as stored in the archive, mode is the permission bits to
  apply (or None to leave them alone), and target is the path the link points
  at.
  """

  def __init__(self, path, mode, target):
    self.path, self.mode, self.target = path, mode, target
52
53
def IterateZip(path):
  """
  IterateZip opens the zip file at path and lazily yields a FileEntry for every
  non-directory member, in archive order. Entries carry no mode (None). The
  zip file stays open for as long as the generator is being consumed.
  """
  with zipfile.ZipFile(path, 'r') as archive:
    for member in archive.infolist():
      # Names ending in a slash are directory placeholders; only members with
      # contents are reported.
      if not member.filename.endswith('/'):
        yield FileEntry(member.filename, None, archive.open(member))
64
65
def IterateTar(path, compression):
  """
  IterateTar opens the compressed tar file at path and lazily yields an entry
  object for each member, in archive order. compression is the suffix passed
  to tarfile after 'r:', e.g. 'gz' or 'bz2'.

  Directories are skipped, symlinks become SymlinkEntry objects with no mode,
  and regular files become FileEntry objects carrying the mode recorded in the
  archive. Any other kind of member raises ValueError.
  """
  open_mode = 'r:' + compression
  with tarfile.open(path, open_mode) as archive:
    for member in archive:
      if member.isdir():
        continue
      if member.issym():
        yield SymlinkEntry(member.name, None, member.linkname)
        continue
      if not member.isfile():
        raise ValueError('Unknown entry type "%s"' % (member.name, ))
      yield FileEntry(member.name, member.mode, archive.extractfile(member))
81
82
def _Sha256OfFile(path):
  """Returns the hex-encoded SHA-256 digest of the file at path."""
  sha256 = hashlib.sha256()
  with open(path, 'rb') as f:
    # Read in 1 MiB chunks so the archive is never held in memory whole.
    for chunk in iter(lambda: f.read(1024 * 1024), b''):
      sha256.update(chunk)
  return sha256.hexdigest()


def main(args):
  """
  main extracts an archive into an output directory.

  args is the command line without the program name: an optional --no-prefix
  flag followed by ARCHIVE and OUTPUT. Unless --no-prefix is given, every entry
  must live under one common top-level directory, which is stripped.

  The SHA-256 digest of the archive is written into OUTPUT once extraction
  succeeds; if a later run finds the same digest there, nothing is done.
  Otherwise any existing OUTPUT is deleted and recreated from the archive.

  Returns 1 when the arguments are malformed and 0 otherwise, including when
  ARCHIVE does not exist. Raises ValueError for an unsupported archive
  extension or an entry whose path is unacceptable.
  """
  parser = optparse.OptionParser(usage='Usage: %prog ARCHIVE OUTPUT')
  parser.add_option('--no-prefix', dest='no_prefix', action='store_true',
                    help='Do not remove a prefix from paths in the archive.')
  options, args = parser.parse_args(args)

  if len(args) != 2:
    parser.print_help()
    return 1

  archive, output = args

  if not os.path.exists(archive):
    # Skip archives that weren't downloaded.
    return 0

  digest = _Sha256OfFile(archive)

  # The stamp records which archive the output directory was extracted from,
  # so an unchanged archive is not extracted again.
  stamp_path = os.path.join(output, ".boringssl_archive_digest")
  if os.path.exists(stamp_path):
    with open(stamp_path) as f:
      if f.read().strip() == digest:
        print("Already up-to-date.")
        return 0

  if archive.endswith('.zip'):
    entries = IterateZip(archive)
  elif archive.endswith('.tar.gz'):
    entries = IterateTar(archive, 'gz')
  elif archive.endswith('.tar.bz2'):
    entries = IterateTar(archive, 'bz2')
  elif archive.endswith('.tar.xz'):
    entries = IterateTar(archive, 'xz')
  else:
    raise ValueError(archive)

  try:
    if os.path.exists(output):
      print("Removing %s" % (output, ))
      shutil.rmtree(output)

    print("Extracting %s to %s" % (archive, output))
    prefix = None
    num_extracted = 0
    for entry in entries:
      # Even on Windows, zip files must always use forward slashes.
      if '\\' in entry.path or entry.path.startswith('/'):
        raise ValueError(entry.path)

      if not options.no_prefix:
        new_prefix, slash, rest = entry.path.partition('/')
        if not slash:
          # A file at the top level of the archive has no prefix directory to
          # strip; report the offending path.
          raise ValueError(entry.path)

        # Ensure the archive is consistent.
        if prefix is None:
          prefix = new_prefix
        if prefix != new_prefix:
          raise ValueError((prefix, new_prefix))
      else:
        rest = entry.path

      # Extract the file into the output directory.
      fixed_path = CheckedJoin(output, rest)
      if not os.path.isdir(os.path.dirname(fixed_path)):
        os.makedirs(os.path.dirname(fixed_path))
      if isinstance(entry, FileEntry):
        try:
          with open(fixed_path, 'wb') as out:
            shutil.copyfileobj(entry.fileobj, out)
        finally:
          # Release the per-entry reader promptly instead of leaving one open
          # handle behind for every extracted file.
          entry.fileobj.close()
      elif isinstance(entry, SymlinkEntry):
        os.symlink(entry.target, fixed_path)
      else:
        raise TypeError('unknown entry type')

      # Fix up permissions if need be.
      # TODO(davidben): To be extra tidy, this should only track the execute bit
      # as in git.
      if entry.mode is not None:
        os.chmod(fixed_path, entry.mode)

      # Print every 100 files, so bots do not time out on large archives.
      num_extracted += 1
      if num_extracted % 100 == 0:
        print("Extracted %d files..." % (num_extracted,))
  finally:
    entries.close()

  # An archive without any file entries never creates the output directory
  # above; make sure it exists so the stamp can be written.
  if not os.path.isdir(output):
    os.makedirs(output)
  with open(stamp_path, 'w') as f:
    f.write(digest)

  print("Done. Extracted %d files." % (num_extracted,))
  return 0
180
181
if __name__ == '__main__':
  # Run as a script: hand the command-line arguments (minus the program name)
  # to main and use its return value as the process exit status.
  exit_status = main(sys.argv[1:])
  sys.exit(exit_status)
184