1# Copyright (c) 2019 Project Nayuki. (MIT License)
2# https://www.nayuki.io/page/free-small-fft-in-multiple-languages
3
4import math, cmath
5
6
def transform_radix2(vector, inverse):
    """Compute the discrete Fourier transform of 'vector' with an iterative radix-2 FFT.

    vector:  sequence of numbers (typically complex) whose length is a power of 2.
             It is not modified; a bit-reversed copy is transformed and returned.
    inverse: False uses the kernel exp(-2*pi*i*j*k/n); True uses exp(+2*pi*i*j*k/n).
             No 1/n scaling is applied, so a forward-then-inverse round trip
             returns the input multiplied by n.

    Returns a new list of length n holding the transform.
    Raises ValueError if len(vector) is zero or not a power of 2.
    """

    # Returns the integer whose value is the reverse of the lowest 'bits' bits of the integer 'x'.
    def reverse(x, bits):
        y = 0
        for _ in range(bits):
            y = (y << 1) | (x & 1)
            x >>= 1
        return y

    # Initialization
    n = len(vector)
    # Validate up front. Without this, a non-power-of-2 length ran part of the
    # algorithm and then failed with an IndexError inside the butterfly loop,
    # and an empty input failed inside math.log2 with "math domain error".
    if n == 0 or n & (n - 1):
        raise ValueError("Length is not a power of 2")
    levels = int(math.log2(n))  # exact, because n is a power of 2
    coef = (2 if inverse else -2) * cmath.pi / n
    # Twiddle factors: exptable[k] = exp(i * k * coef) for k in [0, n/2).
    exptable = [cmath.rect(1, i * coef) for i in range(n // 2)]
    vector = [vector[reverse(i, levels)] for i in range(n)]  # Copy with bit-reversed permutation

    # Radix-2 decimation-in-time FFT
    size = 2
    while size <= n:
        halfsize = size // 2
        tablestep = n // size  # stride through exptable for this stage
        for i in range(0, n, size):
            k = 0
            for j in range(i, i + halfsize):
                temp = vector[j + halfsize] * exptable[k]
                vector[j + halfsize] = vector[j] - temp
                vector[j] += temp
                k += tablestep
        size *= 2
    return vector
37
38
39###########################################################################
40# Benchmark interface
41
# Benchmark size table. Each value is presumably handed to bm_setup as 'params',
# which reads it as (repeat count, FFT length). The key tuples are interpreted
# by the external benchmark harness -- NOTE(review): confirm their meaning there.
bm_params = dict(
    [
        ((50, 25), (2, 128)),
        ((100, 100), (3, 256)),
        ((1000, 1000), (20, 512)),
        ((5000, 1000), (100, 512)),
    ]
)
48
49
def bm_setup(params):
    """Build the run/result closures for the FFT benchmark.

    params: pair (iterations, n). run() performs 'iterations' forward+inverse
            transforms of an n-point signal holding one full period of a cosine.

    Returns (run, result):
      run()    -- executes the workload; it must complete at least one iteration
                  before result() is called.
      result() -- returns (iterations * n, (fft_ok, fft_inv_ok)). The flags report
                  whether the last forward transform and the last (unscaled)
                  inverse transform match their analytic values within 1e-3.
                  The stored transforms are not modified, so calling result()
                  more than once gives the same answer.
    """
    iterations = params[0]
    n = params[1]
    # One full period of cos(2*pi*i/n), stored as complex samples.
    signal = [math.cos(2 * math.pi * i / n) + 0j for i in range(n)]
    fft = None
    fft_inv = None

    def run():
        nonlocal fft, fft_inv
        for _ in range(iterations):
            fft = transform_radix2(signal, False)
            fft_inv = transform_radix2(fft, True)

    def result():
        # A single cosine period transforms to n/2 at bins 1 and n-1 and to zero
        # elsewhere. Accumulate with += so that the bins still sum to n if they
        # coincide (n == 2).
        expected = [0.0] * n
        expected[1] += 0.5 * n
        expected[-1] += 0.5 * n
        fft_ok = all(abs(f - e) < 1e-3 for f, e in zip(fft, expected))
        # transform_radix2 does not scale its inverse, so the round trip gives n * signal.
        fft_inv_ok = all(abs(f - n * s) < 1e-3 for f, s in zip(fft_inv, signal))
        return iterations * n, (fft_ok, fft_inv_ok)

    return run, result
73