1 '''
  2 Radix-4 DIT,  Radix-4 DIF,  Radix-2 DIT,  Radix-2 DIF FFTs
  3 John Bryan,  2017
  4 Python 2.7.3
  5  '''
  6 
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
 11 np.set_printoptions(threshold = np.nan, precision = 3, suppress = 1)
 12 warnings.filterwarnings("ignore")
 13 
 14 
 15 def swap(xarray, i, j):
 16     '''
 17     swap
 18     '''
 19     temp = xarray[i]
 20     xarray[i] = xarray[j]
 21     xarray[j] = temp
 22     return None
 23 
 24 
 25 def digitreversal(xarray, radix, log2length, length):
 26     '''
 27     digitreversal
 28     '''
 29     if log2length % 2 == 0:
 30         n1var = int(np.sqrt(length))       #seed table size 
 31     else:
 32         n1var = int(np.sqrt(int(length/radix)))
 33     # algorithm 2,  compute seed table  
 34     reverse = np.zeros(n1var, dtype = int)
 35     reverse[1] = int(length/radix)
 36     for jvar in range(1, radix):
 37         reverse[jvar] = reverse[jvar-1]+reverse[1]
 38         for i in range(1, int(n1var/radix)):
 39             reverse[radix*i] = int(reverse[i]/radix)
 40             for jvar in range(1, radix):
 41                 reverse[int(radix*i)+jvar] = reverse[int(radix*i)]+reverse[jvar]
 42     #algorithm 1
 43     for i in range(0, n1var-1):
 44         for jvar in range(i+1, n1var):
 45             uvar = i+reverse[jvar]
 46             vvar = jvar+reverse[i]
 47             swap(xarray, uvar, vvar)
 48             if log2length % 2 == 1:
 49                 for zvar in range(1, radix):
 50                     uvar = i+reverse[jvar]+(zvar*n1var)
 51                     vvar = jvar+reverse[i]+(zvar*n1var)
 52                     swap(xarray, uvar, vvar)
 53     return xarray
 54 
 55 
 56 def dif_fft4(xarray, twiddle, svar):
 57     '''
 58     radix-4 dif fft
 59     '''
 60     nvar = np.power(4, svar)
 61     tss = 1
 62     krange = int(float(nvar)/4.)
 63     block = 1
 64     base = 0
 65     for wvar in range(0, svar):
 66         for hvar in range(0, block):
 67             for kvar in range(0, krange):
 68                 # butterfly
 69                 offset = int(nvar/4)
 70                 avar = base+kvar
 71                 bvar = base+kvar+offset
 72                 cvar = base+kvar+(2*offset)
 73                 dvar = base+kvar+(3*offset)
 74                 apc = xarray[avar]+xarray[cvar]
 75                 bpd = xarray[bvar]+xarray[dvar]
 76                 amc = xarray[avar]-xarray[cvar]
 77                 bmd = xarray[bvar]-xarray[dvar]
 78                 xarray[avar] = apc+bpd
 79                 if kvar == 0:
 80                     xarray[bvar] = amc-(1j*bmd)
 81                     xarray[cvar] = apc-bpd
 82                     xarray[dvar] = amc+(1j*bmd)
 83                 else:
 84                     r1var = twiddle[kvar*tss]
 85                     r2var = twiddle[2*kvar*tss]
 86                     r3var = twiddle[3*kvar*tss]
 87                     xarray[bvar] = (amc-(1j*bmd))*r1var
 88                     xarray[cvar] = (apc-bpd)*r2var
 89                     xarray[dvar] = (amc+(1j*bmd))*r3var
 90             base = base+(4*krange)
 91         block = block*4
 92         nvar = float(nvar)/4.
 93         krange = int(float(krange)/4.)
 94         base = 0
 95         tss = int(tss*4)
 96     return xarray
 97 
 98 
 99 
def fft4(xarray, twiddles, svar):
    '''
    radix-4 dit fft

    In-place radix-4 decimation-in-time FFT.

    xarray   -- complex array of length 4**svar whose elements are already
                in base-4 digit-reversed order (testr4 prepares it with
                digitreversal); overwritten with the DFT in natural order
    twiddles -- twiddles[k] = exp(-2j*pi*k/N), N = 4**svar, for
                k = 0 .. 3*(N/4 - 1) (the largest index read)
    svar     -- number of radix-4 stages, log4(N); 0 leaves xarray as is

    Returns xarray.

    All bookkeeping uses integer arithmetic: the previous float twiddle
    stride and `/` divisions produced float indices, which modern NumPy and
    Python 3 reject.
    '''
    nvar = 4                                   # span of one butterfly group
    tss = 4 ** (svar - 1) if svar > 0 else 0   # twiddle stride this stage
    krange = 1                                 # butterflies per group
    block = xarray.size // 4                   # groups in this stage
    for _stage in range(svar):
        offset = nvar // 4                     # leg spacing (== krange)
        base = 0
        for _grp in range(block):
            for kvar in range(krange):
                # butterfly
                avar = base + kvar
                bvar = avar + offset
                cvar = bvar + offset
                dvar = cvar + offset
                if kvar == 0:
                    # W^0 == 1: skip the multiplications
                    xbr1 = xarray[bvar]
                    xcr2 = xarray[cvar]
                    xdr3 = xarray[dvar]
                else:
                    xbr1 = xarray[bvar] * twiddles[kvar * tss]
                    xcr2 = xarray[cvar] * twiddles[2 * kvar * tss]
                    xdr3 = xarray[dvar] * twiddles[3 * kvar * tss]
                evar = xarray[avar] + xcr2
                fvar = xarray[avar] - xcr2
                gvar = xbr1 + xdr3
                j_h = 1j * (xbr1 - xdr3)
                xarray[avar] = evar + gvar
                xarray[bvar] = fvar - j_h
                xarray[cvar] = evar - gvar
                xarray[dvar] = fvar + j_h
            base += 4 * krange
        block //= 4
        nvar *= 4
        krange *= 4
        tss //= 4
    return xarray
145 
146 
def dif_fft0(xarray, twiddle, log2length):
    '''
    radix-2 dif

    In-place radix-2 decimation-in-frequency FFT.

    xarray     -- complex array of length 2**log2length in natural order;
                  overwritten with its DFT in bit-reversed order (testr2dif
                  un-scrambles it with digitreversal)
    twiddle    -- twiddle[k] = exp(-2j*pi*k/N), N = 2**log2length, for
                  k = 0 .. N/2 - 1
    log2length -- number of radix-2 passes, log2(N)

    Returns xarray.

    Uses integer division: with `/` the half-block size became a float on
    Python 3 and broke range() and indexing.
    '''
    b_p = 1                      # number of blocks in this pass
    nvar_p = xarray.size         # block length in this pass
    twiddle_step_size = 1
    for _pass in range(log2length):
        nvar_pp = nvar_p // 2    # half block: distance between butterfly legs
        base_e = 0
        for _blk in range(b_p):
            base_o = base_e + nvar_pp
            for nvar in range(nvar_pp):
                top = xarray[base_e + nvar]
                bot = xarray[base_o + nvar]
                xarray[base_e + nvar] = top + bot
                if nvar == 0:
                    # W^0 == 1: no multiplication needed
                    xarray[base_o + nvar] = top - bot
                else:
                    xarray[base_o + nvar] = \
                        (top - bot) * twiddle[nvar * twiddle_step_size]
            base_e += nvar_p
        b_p *= 2
        nvar_p //= 2
        twiddle_step_size *= 2
    return xarray
174 
175 
def fft2(xarray, twiddle, svar):
    '''
    radix-2 dit

    In-place radix-2 decimation-in-time FFT.

    xarray  -- complex array of length 2**svar whose elements are already
               in bit-reversed order (testr2 prepares it with
               digitreversal); overwritten with the DFT in natural order
    twiddle -- twiddle[k] = exp(-2j*pi*k/N), N = 2**svar, for
               k = 0 .. N/2 - 1
    svar    -- number of radix-2 passes, log2(N)

    Returns xarray.

    Uses integer division: with `/` the block count, half-block size and
    twiddle stride became floats on Python 3 and broke range() and indexing.
    '''
    size = xarray.size
    b_p = size // 2                  # number of blocks in this pass
    nvar_p = 2                       # block length in this pass
    twiddle_step_size = size // 2
    for _pass in range(svar):
        nvar_pp = nvar_p // 2        # distance between butterfly legs
        base_t = 0
        for _blk in range(b_p):
            base_b = base_t + nvar_pp
            for kvar in range(nvar_pp):
                if kvar == 0:
                    # W^0 == 1: no multiplication needed
                    bot = xarray[base_b]
                else:
                    bot = xarray[base_b + kvar] * \
                        twiddle[kvar * twiddle_step_size]
                top = xarray[base_t + kvar]
                xarray[base_t + kvar] = top + bot
                xarray[base_b + kvar] = top - bot
            base_t += nvar_p
        b_p //= 2
        nvar_p *= 2
        twiddle_step_size //= 2
    return xarray
203 
204 
def testr4dif():
    '''
    Test and time dif radix4 w/ multiple length random sequences

    For lengths 4**2 .. 4**7: transform a random complex sequence with
    dif_fft4, un-scramble the result with digitreversal and compare it with
    numpy.fft.fft (AssertionError on the first mismatch).

    Returns a length-6 array of wall-clock times (seconds, time.time) spent
    inside dif_fft4 only.
    '''
    flag = 0
    radix = 4
    r4diftimes = np.zeros(6)
    for i, svar in enumerate(range(2, 8)):
        nvar = 4 ** svar
        xarray = np.random.rand(2 * nvar).view(np.complex128)
        xpy = np.fft.fft(xarray)
        # dif_fft4 reads twiddle[0 .. 3*(N/4 - 1)].  np.linspace needs an
        # integer sample count on modern NumPy, so build the range directly.
        k_wavenumber = np.arange(3 * (nvar // 4 - 1) + 1)
        twiddlefactors = np.exp(-2j * np.pi * k_wavenumber / nvar)
        tvar = time.time()
        xarray = dif_fft4(xarray, twiddlefactors, svar)
        r4diftimes[i] = time.time() - tvar
        xarray = digitreversal(xarray, radix, svar, nvar)
        t_f = np.allclose(xarray, xpy)
        if not t_f:
            flag = 1
        assert t_f
    if flag == 0:
        print("All radix-4 dif results were correct.")
    return r4diftimes
233 
234 
def testr4():
    '''
    Test and time dit radix4 w/ multiple length random sequences

    For lengths 4**2 .. 4**7: digit-reverse a random complex sequence,
    transform it with fft4 and compare it with numpy.fft.fft
    (AssertionError on the first mismatch).

    Returns a length-6 array of wall-clock times (seconds, time.time) spent
    inside fft4 only.
    '''
    flag = 0
    radix = 4
    r4times = np.zeros(6)
    for i, svar in enumerate(range(2, 8)):
        nvar = 4 ** svar
        xarray = np.random.rand(2 * nvar).view(np.complex128)
        xpy = np.fft.fft(xarray)
        xarray = digitreversal(xarray, radix, svar, nvar)
        # fft4 reads twiddles[0 .. 3*(N/4 - 1)].  np.linspace needs an
        # integer sample count on modern NumPy, so build the range directly.
        k_wavenumber = np.arange(3 * (nvar // 4 - 1) + 1)
        twiddles = np.exp(-2j * np.pi * k_wavenumber / nvar)
        tvar = time.time()
        xarray = fft4(xarray, twiddles, svar)
        r4times[i] = time.time() - tvar
        t_f = np.allclose(xarray, xpy)
        if not t_f:
            flag = 1
        assert t_f
    if flag == 0:
        print("All radix-4 dit results were correct.")
    return r4times
262 
263 
def testr2dif():
    '''
    Test and time radix2 dif w/ multiple length random sequences

    For lengths 4**2 .. 4**7 (2**4 .. 2**14): transform a random complex
    sequence with dif_fft0, bit-reverse the result with digitreversal and
    compare it with numpy.fft.fft (AssertionError on the first mismatch).

    Returns a length-6 array of wall-clock times (seconds, time.time) spent
    inside dif_fft0 only.
    '''
    flag = 0
    radix = 2
    r2diftimes = np.zeros(6)
    for i, rvar in enumerate(range(2, 8)):
        nvar = 4 ** rvar            # sequence length
        log2n = 2 * rvar            # number of radix-2 passes
        # np.complex_ was removed in NumPy 2.0; complex128 is the same type.
        cpy = np.random.rand(2 * nvar).view(np.complex128)
        gpy = np.fft.fft(cpy)
        # dif_fft0 reads twiddle[0 .. N/2 - 1]; np.linspace needs an integer
        # sample count on modern NumPy, so build the range directly.
        k_wavenumber = np.arange(nvar // 2)
        twiddles = np.exp(-2j * np.pi * k_wavenumber / nvar)
        t1time = time.time()
        gvar = dif_fft0(cpy, twiddles, log2n)
        r2diftimes[i] = time.time() - t1time
        zvar = digitreversal(gvar, radix, log2n, nvar)
        t_f = np.allclose(zvar, gpy)
        if not t_f:
            flag = 1
        assert t_f
    if flag == 0:
        print("All radix-2 dif results were correct.")
    return r2diftimes
292 
293 
def testr2():
    '''
    Test and time radix2 dit w/ multiple length random sequences

    For lengths 4**2 .. 4**7 (2**4 .. 2**14): bit-reverse a random complex
    sequence with digitreversal, transform it with fft2 and compare it with
    numpy.fft.fft (AssertionError on the first mismatch).

    Returns a length-6 array of wall-clock times (seconds, time.time) spent
    inside fft2 only.
    '''
    radix = 2
    flag = 0
    r2times = np.zeros(6)
    for i, rvar in enumerate(range(2, 8)):
        nvar = 4 ** rvar            # sequence length
        log2n = 2 * rvar            # number of radix-2 passes
        # np.complex_ was removed in NumPy 2.0; complex128 is the same type.
        cpy = np.random.rand(2 * nvar).view(np.complex128)
        gpy = np.fft.fft(cpy)
        # fft2 reads twiddle[0 .. N/2 - 1]; np.linspace needs an integer
        # sample count on modern NumPy, so build the range directly.
        k_wavenumber = np.arange(nvar // 2)
        twiddles = np.exp(-2j * np.pi * k_wavenumber / nvar)
        zvar = digitreversal(cpy, radix, log2n, nvar)
        t1time = time.time()
        garray = fft2(zvar, twiddles, log2n)
        r2times[i] = time.time() - t1time
        t_f = np.allclose(garray, gpy)
        if not t_f:
            flag = 1
        assert t_f
    if flag == 0:
        print("All radix-2 dit results were correct.")
    return r2times
322 
323 
def plot_times(tr2, trdif2, tr4, trdif4):
    '''
    plot performance

    Plot execution time against sequence length on log-log axes (base 4 on
    x, base 10 on y) for the four FFT variants, save the figure to
    'tvl2.png' and show it.

    tr2, trdif2, tr4, trdif4 -- per-length timings in seconds for radix-2
        DIT, radix-2 DIF, radix-4 DIT and radix-4 DIF; entry i belongs to
        length 4**(i + 2), as returned by testr2/testr2dif/testr4/testr4dif.

    Returns None.
    '''
    uvector = np.power(4, np.arange(2, 2 + len(tr2)))
    series = (
        (trdif2, 'o', 'red', 'radix-2 DIF'),
        (tr2, '^', 'green', 'radix-2 DIT'),
        (trdif4, 'D', 'blue', 'radix-4 DIF'),
        (tr4, 's', 'black', 'radix-4 DIT'),
    )
    plt.figure(figsize=(7, 5))
    plt.rc("font", size=9)
    for times, marker, color, label in series:
        plt.plot(uvector, times, marker, ms=5, markerfacecolor="None",
                 markeredgecolor=color, markeredgewidth=1, label=label)
    axes = plt.gca()
    try:
        # matplotlib >= 3.3; the old basex/basey keywords were removed in 3.5
        axes.set_xscale('log', base=4)
        axes.set_yscale('log', base=10)
    except (TypeError, ValueError):
        # older matplotlib only understands basex/basey
        axes.set_xscale('log', basex=4)
        axes.set_yscale('log', basey=10)
    plt.legend(loc=2)
    plt.grid()
    plt.xlim([12, 18500])
    plt.ylim([.00004, 1])
    plt.ylabel("time (seconds)")
    plt.xlabel("sequence length")
    plt.title("Time vs Length")
    plt.savefig('tvl2.png', bbox_inches='tight')
    plt.show()
    return None
355 
356 
def test():
    '''
    test performance

    Run the self-checking timing routines for radix-4 DIF, radix-4 DIT,
    radix-2 DIF and radix-2 DIT (in that order), then plot the collected
    timings with plot_times.  Returns None.
    '''
    runners = (testr4dif, testr4, testr2dif, testr2)
    trdif4, tr4, trdif2, tr2 = [run() for run in runners]
    plot_times(tr2=tr2, trdif2=trdif2, tr4=tr4, trdif4=trdif4)
    return None
367 
368 
if __name__ == "__main__":
    # Run the self-tests and benchmark only when executed as a script, so
    # the FFT routines can be imported without running them (and without
    # opening a plot window).
    test()