1 '''
2 Radix-4 DIT, Radix-4 DIF, Radix-2 DIT, Radix-2 DIF FFTs
3 John Bryan, 2017
4 Python 2.7.3
5 '''
6
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
# Print arrays in full (never summarised with "..."), three decimals, no
# scientific notation.  NumPy rejects NaN as a threshold; sys.maxsize is the
# documented way to ask for an untruncated representation.
np.set_printoptions(threshold=sys.maxsize, precision=3, suppress=True)
# NOTE: this silences every warning category for the whole process,
# including deprecation warnings raised by NumPy and Matplotlib.
warnings.filterwarnings("ignore")
13
14
def swap(xarray, i, j):
    '''
    Exchange, in place, the elements stored at positions i and j.

    xarray : mutable indexable container (list or 1-D array)
    i, j   : the two positions whose contents trade places
    returns: None
    '''
    xarray[i], xarray[j] = xarray[j], xarray[i]
23
24
def digitreversal(xarray, radix, log2length, length):
    '''
    Permute xarray in place into digit-reversed order and return it.

    xarray     : mutable 1-D sequence (list or array) of `length` elements
    radix      : base of the digits being reversed
    log2length : number of base-`radix` digits in an index; despite the
                 name it is only a base-2 logarithm when radix is 2
    length     : number of elements, expected to be radix**log2length
    returns    : xarray (the same object that was passed in)

    Only a small table of reversed "seed" indices (about sqrt(length)
    entries) is built; every pair of positions to exchange is derived from
    two table entries.
    '''
    if log2length % 2 == 0:
        n1var = int(np.sqrt(length))
    else:
        n1var = int(np.sqrt(length // radix))

    # With zero or one digit the reversal is the identity, and the seed
    # table would have no slot 1 to initialise.
    if n1var < 2:
        return xarray

    # reverse[k] is the digit reversal of k, placed in the high digits.
    reverse = np.zeros(n1var, dtype=int)
    reverse[1] = length // radix
    for jvar in range(1, radix):
        reverse[jvar] = reverse[jvar-1] + reverse[1]
    for i in range(1, n1var // radix):
        reverse[radix*i] = reverse[i] // radix
        for jvar in range(1, radix):
            reverse[radix*i + jvar] = reverse[radix*i] + reverse[jvar]

    odd_digits = log2length % 2 == 1
    for i in range(n1var - 1):
        for jvar in range(i + 1, n1var):
            uvar = i + reverse[jvar]
            vvar = jvar + reverse[i]
            xarray[uvar], xarray[vvar] = xarray[vvar], xarray[uvar]
            if odd_digits:
                # The middle digit maps to itself; repeat the exchange for
                # each of its non-zero values.
                for zvar in range(1, radix):
                    shift = zvar * n1var
                    ushift = uvar + shift
                    vshift = vvar + shift
                    xarray[ushift], xarray[vshift] = \
                        xarray[vshift], xarray[ushift]
    return xarray
54
55
def dif_fft4(xarray, twiddle, svar):
    '''
    radix-4 decimation-in-frequency fft, computed in place

    xarray  : array of 4**svar complex samples, overwritten stage by stage
    twiddle : twiddle-factor table; a butterfly at offset k reads entries
              k*stride, 2*k*stride and 3*k*stride, where stride is 1 in
              the first stage and grows by 4 per stage
    svar    : number of radix-4 stages
    returns : xarray (the same object that was passed in)
    '''
    span = int(np.power(4, svar))   # sub-transform length in this stage
    stride = 1                      # spacing of this stage's twiddles
    groups = 1                      # number of sub-transforms in this stage
    for _ in range(svar):
        quarter = span // 4
        for group in range(groups):
            origin = group * span
            for k in range(quarter):
                p0 = origin + k
                p1 = p0 + quarter
                p2 = p1 + quarter
                p3 = p2 + quarter
                sum02 = xarray[p0] + xarray[p2]
                sum13 = xarray[p1] + xarray[p3]
                dif02 = xarray[p0] - xarray[p2]
                dif13 = xarray[p1] - xarray[p3]
                xarray[p0] = sum02 + sum13
                out1 = dif02 - (1j*dif13)
                out2 = sum02 - sum13
                out3 = dif02 + (1j*dif13)
                if k != 0:
                    # offset 0 has unit twiddles, so the products are skipped
                    out1 = out1 * twiddle[k*stride]
                    out2 = out2 * twiddle[2*k*stride]
                    out3 = out3 * twiddle[3*k*stride]
                xarray[p1] = out1
                xarray[p2] = out2
                xarray[p3] = out3
        groups *= 4
        span = quarter
        stride *= 4
    return xarray
97
98
99
def fft4(xarray, twiddles, svar):
    '''
    radix-4 decimation-in-time fft, computed in place

    xarray   : numpy array of 4**svar complex samples; the callers in this
               file pass it already permuted by digitreversal
    twiddles : twiddle-factor table; a butterfly at offset k reads entries
               k*tss, 2*k*tss and 3*k*tss, where tss starts at 4**(svar-1)
               and shrinks by 4 per stage
    svar     : number of radix-4 stages
    returns  : xarray (the same object that was passed in)

    All sizes, strides and offsets are kept as integers (floor division),
    so they stay valid array indices and range() bounds on Python 3.
    '''
    nvar = 4                    # sub-transform length in the current stage
    tss = (4 ** svar) // 4      # twiddle stride for the current stage
    block = xarray.size // 4    # number of sub-transforms in this stage
    for _ in range(svar):
        offset = nvar // 4      # distance between the four butterfly legs
        for zvar in range(block):
            base = zvar * nvar
            for kvar in range(offset):
                avar = base + kvar
                bvar = avar + offset
                cvar = bvar + offset
                dvar = cvar + offset
                if kvar == 0:
                    # unit twiddles: skip the three complex products
                    xbr1 = xarray[bvar]
                    xcr2 = xarray[cvar]
                    xdr3 = xarray[dvar]
                else:
                    xbr1 = xarray[bvar] * twiddles[kvar*tss]
                    xcr2 = xarray[cvar] * twiddles[2*kvar*tss]
                    xdr3 = xarray[dvar] * twiddles[3*kvar*tss]
                evar = xarray[avar] + xcr2
                fvar = xarray[avar] - xcr2
                gvar = xbr1 + xdr3
                j_h = 1j * (xbr1 - xdr3)
                xarray[avar] = evar + gvar
                xarray[bvar] = fvar - j_h
                xarray[cvar] = evar - gvar
                xarray[dvar] = fvar + j_h
        block //= 4
        nvar *= 4
        tss //= 4
    return xarray
145
146
def dif_fft0(xarray, twiddle, log2length):
    '''
    radix-2 decimation-in-frequency fft, computed in place

    xarray     : numpy array of 2**log2length complex samples, overwritten
                 stage by stage
    twiddle    : twiddle-factor table; a butterfly at offset n reads entry
                 n*twiddle_step_size, where the step starts at 1 and
                 doubles per stage
    log2length : number of radix-2 stages
    returns    : xarray (the same object that was passed in)

    Sub-transform sizes are halved with floor division so that they remain
    integers usable as range() bounds and indices on Python 3.
    '''
    b_p = 1                   # number of sub-transforms in this stage
    nvar_p = xarray.size      # length of each sub-transform
    twiddle_step_size = 1
    for _ in range(log2length):
        nvar_pp = nvar_p // 2
        base_e = 0
        for _block in range(b_p):
            base_o = base_e + nvar_pp
            for nvar in range(nvar_pp):
                upper = xarray[base_e+nvar]
                lower = xarray[base_o+nvar]
                ovar = upper - lower
                if nvar != 0:
                    # offset 0 has a unit twiddle, so the product is skipped
                    ovar = ovar * twiddle[nvar*twiddle_step_size]
                xarray[base_e+nvar] = upper + lower
                xarray[base_o+nvar] = ovar
            base_e += nvar_p
        b_p *= 2
        nvar_p = nvar_pp
        twiddle_step_size *= 2
    return xarray
174
175
def fft2(xarray, twiddle, svar):
    '''
    radix-2 decimation-in-time fft, computed in place

    xarray  : numpy array of 2**svar complex samples; the caller in this
              file passes it already permuted by digitreversal
    twiddle : twiddle-factor table; a butterfly at offset n reads entry
              n*twiddle_step_size, where the step starts at size/2 and
              halves per stage
    svar    : number of radix-2 stages
    returns : xarray (the same object that was passed in)

    Block counts and strides are kept as integers (floor division) so they
    remain valid range() bounds and indices on Python 3.
    '''
    length = xarray.size
    b_p = length // 2                 # number of sub-transforms this stage
    nvar_p = 2                        # length of each sub-transform
    twiddle_step_size = length // 2
    for _ in range(svar):
        nvar_pp = nvar_p // 2
        base_t = 0
        for _block in range(b_p):
            base_b = base_t + nvar_pp
            for nvar in range(nvar_pp):
                bot = xarray[base_b+nvar]
                if nvar != 0:
                    # offset 0 has a unit twiddle, so the product is skipped
                    bot = bot * twiddle[nvar*twiddle_step_size]
                top = xarray[base_t+nvar]
                xarray[base_t+nvar] = top + bot
                xarray[base_b+nvar] = top - bot
            base_t += nvar_p
        b_p //= 2
        nvar_p *= 2
        twiddle_step_size //= 2
    return xarray
203
204
def testr4dif():
    '''
    Test and time dif radix4 w/ multiple length random sequences

    For each length 4**2 .. 4**7 a random complex sequence is transformed
    with dif_fft4, reordered with digitreversal and compared against
    np.fft.fft.

    returns : array of six wall-clock durations in seconds, covering the
              dif_fft4 call only (twiddle setup and reordering excluded)
    raises  : AssertionError on the first length whose result differs
    '''
    radix = 4
    exponents = range(2, 8)
    r4diftimes = np.zeros(len(exponents))
    for i, svar in enumerate(exponents):
        nvar = radix ** svar
        xarray = np.random.rand(2*nvar).view(np.complex128)
        xpy = np.fft.fft(xarray)
        # Largest twiddle index used is 3*(nvar/4 - 1); build the table
        # from integer indices so the count is never a float.
        kmax = 3 * (nvar // 4 - 1)
        twiddlefactors = np.exp(-2j*np.pi*np.arange(kmax + 1)/nvar)
        tvar = time.time()
        xarray = dif_fft4(xarray, twiddlefactors, svar)
        r4diftimes[i] = time.time() - tvar
        xarray = digitreversal(xarray, radix, svar, nvar)
        # Raised explicitly so the check survives "python -O".
        if not np.allclose(xarray, xpy):
            raise AssertionError(
                "radix-4 dif result differs from np.fft.fft for length %d"
                % nvar)
    print("All radix-4 dif results were correct.")
    return r4diftimes
233
234
def testr4():
    '''
    Test and time dit radix4 w/ multiple length random sequences

    For each length 4**2 .. 4**7 a random complex sequence is reordered
    with digitreversal, transformed with fft4 and compared against
    np.fft.fft.

    returns : array of six wall-clock durations in seconds, covering the
              fft4 call only (reordering and twiddle setup excluded)
    raises  : AssertionError on the first length whose result differs
    '''
    radix = 4
    exponents = range(2, 8)
    r4times = np.zeros(len(exponents))
    for i, svar in enumerate(exponents):
        nvar = radix ** svar
        xarray = np.random.rand(2*nvar).view(np.complex128)
        xpy = np.fft.fft(xarray)
        xarray = digitreversal(xarray, radix, svar, nvar)
        # Largest twiddle index used is 3*(nvar/4 - 1); build the table
        # from integer indices so the count is never a float.
        kmax = 3 * (nvar // 4 - 1)
        twiddles = np.exp(-2j*np.pi*np.arange(kmax + 1)/nvar)
        tvar = time.time()
        xarray = fft4(xarray, twiddles, svar)
        r4times[i] = time.time() - tvar
        # Raised explicitly so the check survives "python -O".
        if not np.allclose(xarray, xpy):
            raise AssertionError(
                "radix-4 dit result differs from np.fft.fft for length %d"
                % nvar)
    print("All radix-4 dit results were correct.")
    return r4times
262
263
def testr2dif():
    '''
    Test and time radix2 dif w/ multiple length random sequences

    For each length 4**2 .. 4**7 (that is, 2**(2*rvar)) a random complex
    sequence is transformed with dif_fft0, reordered with digitreversal
    and compared against np.fft.fft.

    returns : array of six wall-clock durations in seconds, covering the
              dif_fft0 call only (twiddle setup and reordering excluded)
    raises  : AssertionError on the first length whose result differs
    '''
    radix = 2
    exponents = range(2, 8)
    r2diftimes = np.zeros(len(exponents))
    for i, rvar in enumerate(exponents):
        nvar = 4 ** rvar
        stages = 2 * rvar            # log2 of the sequence length
        cpy = np.random.rand(2*nvar).view(np.complex128)
        gpy = np.fft.fft(cpy)
        # Twiddle indices run 0 .. nvar/2 - 1; an integer arange keeps the
        # table length from ever being a float.
        twiddles = np.exp(-2j*np.pi*np.arange(nvar // 2)/nvar)
        t1time = time.time()
        gvar = dif_fft0(cpy, twiddles, stages)
        r2diftimes[i] = time.time() - t1time
        zvar = digitreversal(gvar, radix, stages, nvar)
        # Raised explicitly so the check survives "python -O".
        if not np.allclose(zvar, gpy):
            raise AssertionError(
                "radix-2 dif result differs from np.fft.fft for length %d"
                % nvar)
    print("All radix-2 dif results were correct.")
    return r2diftimes
292
293
def testr2():
    '''
    Test and time radix2 dit w/ multiple length random sequences

    For each length 4**2 .. 4**7 (that is, 2**(2*rvar)) a random complex
    sequence is reordered with digitreversal, transformed with fft2 and
    compared against np.fft.fft.

    returns : array of six wall-clock durations in seconds, covering the
              fft2 call only (reordering and twiddle setup excluded)
    raises  : AssertionError on the first length whose result differs
    '''
    radix = 2
    exponents = range(2, 8)
    r2times = np.zeros(len(exponents))
    for i, rvar in enumerate(exponents):
        nvar = 4 ** rvar
        stages = 2 * rvar            # log2 of the sequence length
        cpy = np.random.rand(2*nvar).view(np.complex128)
        gpy = np.fft.fft(cpy)
        # Twiddle indices run 0 .. nvar/2 - 1; an integer arange keeps the
        # table length from ever being a float.
        twiddles = np.exp(-2j*np.pi*np.arange(nvar // 2)/nvar)
        zvar = digitreversal(cpy, radix, stages, nvar)
        t1time = time.time()
        garray = fft2(zvar, twiddles, stages)
        r2times[i] = time.time() - t1time
        # Raised explicitly so the check survives "python -O".
        if not np.allclose(garray, gpy):
            raise AssertionError(
                "radix-2 dit result differs from np.fft.fft for length %d"
                % nvar)
    print("All radix-2 dit results were correct.")
    return r2times
322
323
def plot_times(tr2, trdif2, tr4, trdif4):
    '''
    plot performance

    Draws the four timing series against sequence length (4**2 .. 4**7) on
    log-log axes, saves the figure as 'tvl2.png' and shows it.

    tr2    : six timings for the radix-2 DIT transform
    trdif2 : six timings for the radix-2 DIF transform
    tr4    : six timings for the radix-4 DIT transform
    trdif4 : six timings for the radix-4 DIF transform
    returns: None

    The log bases are set through xscale/yscale(base=...), which needs
    Matplotlib 3.3 or newer; the old basex/basey keywords of loglog are no
    longer accepted by current Matplotlib.
    '''
    uvector = 4 ** np.arange(2, 8)
    # (timings, marker, edge colour, legend label), in drawing order
    series = (
        (trdif2, 'o', 'red', 'radix-2 DIF'),
        (tr2, '^', 'green', 'radix-2 DIT'),
        (trdif4, 'D', 'blue', 'radix-4 DIF'),
        (tr4, 's', 'black', 'radix-4 DIT'),
    )
    plt.figure(figsize=(7, 5))
    plt.rc("font", size=9)
    for times, marker, colour, label in series:
        plt.plot(uvector, times, marker, ms=5, markerfacecolor="None",
                 markeredgecolor=colour, markeredgewidth=1, label=label)
    plt.xscale("log", base=4)
    plt.yscale("log", base=10)
    plt.legend(loc=2)
    plt.grid()
    plt.xlim([12, 18500])
    plt.ylim([.00004, 1])
    plt.ylabel("time (seconds)")
    plt.xlabel("sequence length")
    plt.title("Time vs Length")
    plt.savefig('tvl2.png', bbox_inches='tight')
    plt.show()
355
356
def test():
    '''
    Run the four correctness/timing checks, then plot their timings.

    The checks run in the order radix-4 DIF, radix-4 DIT, radix-2 DIF,
    radix-2 DIT; each returns its array of timings, which is handed to
    plot_times.
    returns: None
    '''
    dif4_times = testr4dif()
    dit4_times = testr4()
    dif2_times = testr2dif()
    dit2_times = testr2()
    plot_times(dit2_times, dif2_times, dit4_times, dif4_times)
367
368
# Run the checks, benchmark and plot only when executed as a script, so that
# importing this module for its FFT routines has no side effects.
if __name__ == "__main__":
    test()