任意长度的 FFT 算法-混合基算法及雷德算法

之前我写的 FFT 用的是最基本的 库利-图基算法,而且是取模2,因此只能实现 2 的幂次方长度的运算。那么如果点数不是 2 的幂次方该怎么办呢?当点数为合数的时候,可以用库利-图基混合基算法;当点数为质数的时候,可以用雷德算法

库利-图基混合基算法

公式推导

当变换点数是合数的时候,则可以使用混合基算法。因为 \(N\) 是合数,我们有 \(N = N_1N_2\)。注意到 DFT 的公式为

\[X_k = \sum_{n=0}^{N-1}x_ne^{-\frac{2 \pi i}{N}nk}\]

我们可以用 \(k = N_2k_1 + k_2,k_1 \in [0, N_1-1],k_2 \in [0, N_2-1]\) 以及 \(n = N_1n_2+n_1,n_1 \in [0, N_1-1],n_2 \in [0, N_2-1]\) 带入,得到

\(\begin{split}X_{N_2k_1+k_2}&=\sum_{n_1=0}^{N_1-1}\sum_{n_2=0}^{N_2-1}x_{N_1n_2+n_1}e^{- \frac{2 \pi i}{N_1N_2}(N_1n_2+n_1)(N_2k_1+k_2)}\\\ &=\sum_{n_1=0}^{N_1-1} \left( e^{- \frac{2 \pi i}Nn_1k_2} \right) \left(\sum_{n_2=0}^{N_2-1}x_{N_1n_2+n_1}e^{- \frac{2 \pi i}{N_2}n_2k_2} \right)e^{-\frac{2 \pi i}{N_1}n_1k_1} \end{split}\)

注意到这里其实有两个 DFT 变换,一个是 \(N_1\) 个内部长度为 \(N_2\) 的变换,另外一个是 \(N_2\) 个外部长度为 \(N_1\) 的变换,其旋转因子为 \(e^{- \frac{2 \pi i}Nn_1k_2}\)

我们注意到 \(\sum_{n_2=0}^{N_2-1}x_{N_1n_2+n_1}e^{- \frac{2 \pi i}{N_2}n_2k_2}\) 实际上是 \(N_1\) 个 长度为 \(N_2\) 的 DFT 变换。每一组是 \(x_{N_1n_2+n_1},n_1 \in [0, N_1-1]\)。再乘以旋转因子 \(e^{- \frac{2 \pi i}Nn_1k_2}\)。然后再做 \(N_2\) 个 长度为 \(N_1\) 的 DFT 变换。

具体实现

具体做法是将原始信号重新排列再进行计算。步骤如下

  1. 将信号按列存储为一个矩阵
  2. 对每行计算 \(N_2\) 点 DFT
  3. 将矩阵每一项乘以旋转因子 \(e^{- \frac{2 \pi i}Nn_1n_2}\)
  4. 对每列计算 \(N_1\) 点 DFT
  5. 将结果数组按行读出

将原始信号按列排列为 \(N_1\)\(N_2\) 列的数组,记为 \(x[n_1][n_2] = x[N_1*n_2+n_1]\)。对每一行进行 DFT 变换,结果记为 \(X_1[n_1][n_2]\)。再对 \(X_1\) 乘以 \(e^{- \frac{2 \pi i}Nn_1k_2}\),即 \(X_2[n_1][n_2] = X_1e^{- \frac{2 \pi i}Nn_1n_2}\)。最后对 \(X_2\) 的每一列进行 DFT 变换得到 \(X_3\),那么最终的结果就是 \(X[N_2k_1+k_2] = X_3[k_1][k_2]\)

雷德算法

公式推导

雷德算法适用于变换点数为质数的情况,我们略过严谨的数学证明,只看步骤。如果 N 是一个质数,定义 \(a,b\) 之间运算 · 为 $ (a*b) mod N $。则 $ [0, N-1] $ 连同运算 · 可以构成一个,该群也叫做整数模n乘法群

则根据数论,这样的群有群的生成集合,存在一个整数 \(g\) (这个整数也叫做原根,我们将在下文叙述如何寻找)使得对于任意非零的 \(n\) 对应唯一一个 \(q \in [0, N-2]\) 使得 \(n = g^q (mod\ N)\)。其实就是一个 \(q \in [0, N-2]\) 和非零 \(n\) 之间的双射。同样有非零整数 \(k\) 和一个 \(p \in [0, N-2]\) 使得 \(k = g^{-p}(mod\ N)\),也是双射。这里的负指数,指的是模逆元。则我们可以将 DFT 用上述的 \(p\)\(q\) 将原本的 DFT

\[X_k=\sum_{n=0}^{N-1}x_ne^{\frac {-2\pi i}Nnk } \ \ \ k=0,...,N-1\]

重写如下

\[X_0= \sum_{n=0}^{N-1}x_n\]

\[X_{g^{-p}(mod\ N)}=x_0 + \sum_{q=0}^{N-2}x_{g^q}e^{\frac {-2\pi i}Ng^{p-q}(mod\ N)}\]

这个重写 DFT 实际上是将原始 DFT 公式中的 \(k\) 替换为\(g^{-p}(mod\ N)\),将 \(n\) 替换为 \(g^q (mod\ N)\) 。最后我们定义两个 N-1 长度的两个序列 \(a_q\)\(b_q\) 如下

\[a_q = x_{g^q}\]

\[b_q = e^{\frac {-2 \pi i}{N}g^{-q}}\]

注意到 \(\sum_{q=0}^{N-2}x_{g^q}e^{\frac {-2\pi i}Ng^{p-q}(mod\ N)}\) 其实是 \(a_q\)\(b_q\) 的循环卷积。则根据卷积定理

\[a_q * b_q = dft^{-1}(dft(a_q)\ dft(b_q))\]

这样,\(N\) 长度的 DFT 变换便被分解为长度为 \(N-1\) 的两个 DFT 和一个长度为 \(N-1\) 的 DFT 逆变换。当然,如果 \(N-1\) 因式分解后还是有大质数,那还是得继续使用雷德算法接着分解,这样的耗时也会很大。不过对于 \(a_q\)\(b_q\) 的循环卷积,也可以使用补零到 2 的幂次方长度再进行 FFT 实现。显然\(a_q\)\(b_q\) 这两个序列都是 \(N-1\) 长度的

填补方法

注意到 \(a_q\)\(b_q\),可以对 \(a_q\)\(b_q\) 进行扩充得到 \(\hat{a_q}\)\(\hat{b_q}\),使得其长度变为 2 的幂次方。一般是扩充到长度 \(M' \geq 2*N-3\) 且 $ M' $ 是 2 的幂次方。扩充的方法得注意一下,不是直接在后面补零就可以。

  • 对于序列 \(a_q\),在第 0 个元素和第 1 个元素间填补 \(M'-N+1\) 个零得到 \(\hat{a_q}\)
  • 对于序列 \(b_q\),在其后循环重复 \(b_q\) 本身直到其长度为 \(M'\) 得到 \(\hat{b_q}\)

\(M=N-1\) 我们可以简单地证明一下正确性:

\[\begin{split}\hat{a_q} * \hat{b_q}[n] =& \sum_{m= -\infty}^{\infty}\hat{a_q}[m]\hat{b_q}[n-m] \\\ =&a_q[0]\cdot \hat{b_q}[n]+0 \cdot \hat{b_q}[n-1]+\cdots + 0 \cdot \hat{b_q}[n-M'+M] + \\\ &a_q[1] \cdot \hat{b_q}[n-M'+M-1]+a_q[2] \cdot \hat{b_q}[n-M'+M-2] \\\ &+ \cdots + a_q[M-1] \cdot \hat{b_q}[n-M] \\\ =&a_q[0] \cdot b_q[n] + a_q[1] \cdot b_q[n-1]+ \cdots +a_q[n]b_q[n-M] \\\ =&a_q * b_q [n] \end{split}\]

寻找原根

原根定义如下:对于两个正整数 \(a\ mod\ N=1\)。定义 \(a\)\(N\) 的阶数为使得 \(a^d = 1(mod \ N)\) 成立的最小正整数 \(d\)。如果该阶数为 \(\phi(N)=d\),则称 \(a\) 为原根(对没错,这个 \(\phi\) 就是我之前这篇博文中的欧拉函数)。

目前没有直接方法能够计算得到原根,基本上只能按照原根定义公式去判断这个是不是原根。因为雷德算法只处理质数长度的 FFT,显然有 \(\phi(N) = N-1\)。对于要判断一个数 \(a\) 是不是原根,我们要对 \(d\)\(1\) 试到 \(N-1\),如果只有 \(d=N-1\)\(a^d = 1(mod \ N)\) 才成立,则 \(a\) 是原根。

当然有快一点的方法,首先我们将 \(N-1\) 进行因式分解,得到 \(k\) 个质因数 \(N-1=p_1p_2...p_k\),那么如果要判断 \(a\) 是不是原根。只需要判断下面的式子

\[ a^{ \frac {N-1} {p_i} }\ mod \ N \]

对所有的 \(p_i\) 都不等于1,就可以认为 \(a\) 是原根。这里,简单证明一下(感谢竺可桢学院的大仙)

首先我们证明一下引理:如果 \(a \ mod \ N =1,a^d=1(mod \ N)\),则 \(a\)\(N\) 的阶数 \(Order_N(a)\) 能被 \(d\) 整除

\(d=Order_N(a) \cdot k+x,x < Order_N(a)\),则 \(a^{Order_N(a) \cdot k+x}=1(mod \ N)\),由 \(x < Order_N(a)\) 和阶的定义知 \(x=0\)

现在来证明一下这个判别方法:如果 \(a\) 不是原根,那么 \(0<Order_N(a)<N-1\),根据费马小定理有 \(a^{N-1} = 1(mod \ N)\)。根据刚才的引理可以有 \(Order_N(a)\) 能被 \(N-1\) 整除。所以\(Order_N(a)\) 至少在 \(\frac {N-1} {p_1},\frac {N-1} {p_2}, \cdots ,\frac {N-1} {p_k}\) 有一个。因此存在一个 \({p_i}\) 使得 \(a^{ \frac {N-1} {p_i}}=1(mod \ N)\)。得证。

C++ 代码实现

瞎逼逼了这么多,对伸手党提供一下 C++ 的代码,包含了快速傅立叶变换和逆变换。使用了纯模版编写,方便兼容各种浮点型。对于长度为 8 的变换直接进行计算,加快了速度。

FFT.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
//
// FFT.h
// 该文件提供了FFT相关函数
//

#ifndef FFT_H
#define FFT_H

#include <vector>
#include <deque>
#include <numeric>
#include <complex>

#include "Number.h"


#ifndef M_PI
#define M_PI 3.141592653589793238462643
#endif

#ifndef DOUBLE_PI
#define DOUBLE_PI 6.283185307179586476925287
#endif

//函数声明
namespace FFT {

/**
* fft计算,当输入数据长度是2的幂时使用基2-FFT算法迭代实现,复杂度O(nlogn);否则使用雷德算法和库利-图基混合基算法
*
* @param data 输入数据(复数)
*
* @return 输出数据(复数)
*
* @see fft(const std::vector<std::complex<T> > &data,size_t N)
* @see fft(const std::vector<T> &data)
* @see recursiveFFT(const std::vector<std::complex<T> > &data)
* @see fft(const std::vector<T> &data, size_t N)
*/
template <typename T> std::vector<std::complex<T>> fft(const std::vector<std::complex<T> > &data);

/**
* fft计算, 当输入数据长度是2的幂时使用基2-FFT算法迭代实现,复杂度O(nlogn);否则使用雷德算法和库利-图基混合基算法
*
* @param data 输入数据(自动转换为复数)
*
* @return 输出数据(复数)
*
* @see fft(const std::vector<std::complex<T> > &data)
* @see fft(const std::vector<std::complex<T> > &data, size_t N)
* @see recursiveFFT(const std::vector<std::complex<T> > &data)
* @see fft(const std::vector<T> &data, size_t N)
*/
template <typename T> std::vector<std::complex<T> > fft(const std::vector<T> &data);

/**
* fft计算, 当输入数据长度是2的幂时使用基2-FFT算法迭代实现,复杂度O(nlogn);否则使用雷德算法和库利-图基混合基算法
*
* @param data 输入数据(复数)
* @param N 长度
*
* @return 输出数据(复数)
*
* @see fft(const std::vector<std::complex<T>> &data)
* @see fft(const std::vector<T> &data)
* @see recursiveFFT(const std::vector<std::complex<T>> &data)
* @see fft(const std::vector<T> &data,size_t N)
*/
template <typename T> std::vector<std::complex<T> > fft(const std::vector<std::complex<T> > &data, size_t N);

/**
* fft计算, 当输入数据长度是2的幂时使用基2-FFT算法迭代实现,复杂度O(nlogn);否则使用雷德算法和库利-图基混合基算法
*
* @param data 输入数据(复数)
* @param N 长度
*
* @return 输出数据(复数)
*
* @see fft(const std::vector<std::complex<T> > &data)
* @see fft(const std::vector<T> &data)
* @see recursiveFFT(const std::vector<std::complex<T> > &data)
* @see template <typename T> void fft(const std::vector<std::complex<T> > &data,size_t N)
*/
template <typename T> std::vector<std::complex<T>> fft(const std::vector<T> &data, size_t N);

/**
* ifft计算,直接使用 data 的长度作为运算点数
* @tparam T 可以是任意浮点型
* @param data 输入复数
*
* @return result 结果
*
* @see ifft(const std::vector<T> &data)
* @see ifft(const std::vector<std::complex<T> > &data, size_t N)
*/
template <typename T> std::vector<std::complex<T>> ifft(const std::vector<std::complex<T>> &data);

/**
* ifft计算,直接使用 data 的长度作为运算点数
* @tparam T 可以是任意浮点型
* @param data 输入实数,自动转换为复数运算
*
* @return 结果(复数)
*
* @see ifft(const std::vector<std::complex<T>> &data)
* @see ifft(const std::vector<std::complex<T> > &data, size_t N)
*/
template <typename T> std::vector<std::complex<T>> ifft(const std::vector<T> &data);

/**
* ifft计算,当输入数据长度是2的幂时使用基2-FFT算法迭代实现,复杂度O(nlogn);
*
* @param data 输入数据
* @param N 长度
*
* @return result 输出数据
*
* @see ifft(const std::vector<std::complex<T>> &data)
* @see ifft(const std::vector<T> &data)
*/
template <typename T> std::vector<std::complex<T> > ifft(const std::vector<std::complex<T> > &data, size_t N);

/**
* ifft计算,当输入数据长度是2的幂时使用基2-FFT算法迭代实现,复杂度O(nlogn);
*
* @param data 输入数据
* @param N 长度
*
* @return 输出数据(复数)
*
* @see ifft(const std::vector<std::complex<T> > &data, size_t N)
* @see ifft(const std::vector<T> &data)
* @see ifft(const std::vector<std::complex<T>> &data)
*/
template <typename T> std::vector<std::complex<T>> ifft(const std::vector<T> &data, size_t N);

namespace FFTInner {

/**
* radix-2 库利-图基迭代计算fft,参照算法导论中有关FFT部分
*
* @param data 原始数据(复数)
*
* @return fft后的结果(复数)
*/
template <typename T> std::vector<std::complex<T>> radix2FFT(const std::vector<std::complex<T>> &data);

/**
* 固定的8点fft算法,非递归的蝶形算法
*
* @param data 原始数据(复数)
*
* @return fft后的结果(复数)
*/
template <typename T> std::vector<std::complex<T>> fft8(const std::vector<std::complex<T>> &data);

/**
* 迭代计算ifft,参照算法导论中有关FFT部分,注意这个结果需要除以 N 才行!
*
* @tparam T 必须是浮点型
* @param data 原始数据(复数)
*
* @return ifft后的结果(复数)
*/
template <typename T> std::vector<std::complex<T>> radix2IFFT(const std::vector<std::complex<T>> &data);

/**
* 固定的8点ifft算法,非递归的蝶形算法
*
* @param data 原始数据(复数)
*
* @return ifft后的结果(复数)
*/
template <typename T> std::vector<std::complex<T>> ifft8(const std::vector<std::complex<T>> &data);

/**
* 雷德算法 fft
* @tparam T 必须是浮点型
* @param data 长度必须是素数
*
* @return 雷德算法结果(复数)
*/
template <typename T> std::vector<std::complex<T>> raderFFT(const std::vector<std::complex<T>> &data);

/**
* 库利-图基混合基算法加雷德算法 fft
*
* @tparam T 必须是浮点型
* @param data
*
* @return fft结果(复数)
*/
template <typename T> std::vector<std::complex<T>> hybridFFT(const std::vector<std::complex<T>> &data);

/**
* 雷德算法 ifft
* @tparam T 必须是浮点型
* @param data 长度必须是素数
*
* @return ifft结果
*/
template <typename T> std::vector<std::complex<T>> raderIFFT(const std::vector<std::complex<T>> &data);

/**
* 库利-图基混合基算法加雷德算法 ifft
* @tparam T
* @param data
*
* @return ifft结果
*/
template <typename T> std::vector<std::complex<T>> hybridIFFT(const std::vector<std::complex<T>> &data);

/**
* 计算比n大的最小的2的幂
*
* @param n n
*
* @return 最小的2的幂
*/
template <typename T> T nextPowerOf2(T n);

/**
* 将各类实数(浮点数和整数)转换为复数
*
* @param data 实数
*
* @return 复数
*/
template <typename T> std::vector<std::complex<T>> toComplex(const std::vector<T> &data);

}
}


//函数定义
namespace FFT {

template <typename T> std::vector<std::complex<T>> fft(const std::vector<T> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

return fft(FFTInner::toComplex(data));
}

template <typename T> std::vector<std::complex<T>> fft(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

if (FFTInner::nextPowerOf2(data.size()) != data.size())
return FFTInner::hybridFFT(data);
else
return FFTInner::radix2FFT(data);
}

template <typename T> std::vector<std::complex<T>> fft(const std::vector<std::complex<T>> &data, size_t N) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

std::vector <std::complex<T>> da = data;
da.resize(N, std::complex<T>(0.0, 0.0));
if (FFTInner::nextPowerOf2(N) != N) {
return FFTInner::hybridFFT(da);
} else {
return FFTInner::radix2FFT(da);
}
}

template <typename T> std::vector<std::complex<T> > fft(const std::vector<T> &data, size_t N) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

return fft(FFTInner::toComplex(data), N);
}

template <typename T> std::vector<std::complex<T>> ifft(const std::vector<std::complex<T>> &data){
static_assert(std::is_floating_point<T>::value, "T must float type!");

return ifft(data, data.size());
}

template <typename T> std::vector<std::complex<T>> ifft(const std::vector<T> &data){
static_assert(std::is_floating_point<T>::value, "T must float type!");

return ifft(FFTInner::toComplex(data), data.size());
}

template <typename T> std::vector<std::complex<T>> ifft(const std::vector<T> &data, size_t N) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

return ifft(FFTInner::toComplex(data), N);
}

template <typename T> std::vector<std::complex<T>> ifft(const std::vector<std::complex<T> > &data, size_t N) {
static_assert(std::is_floating_point<T>::value, "T must float type!");


std::vector <std::complex<T>> da = data;
da.resize(N, std::complex<T>(0.0, 0.0));
if (FFTInner::nextPowerOf2(N) != N) {
return FFTInner::hybridIFFT(da);
} else {
auto result = FFTInner::radix2IFFT(da);
for (auto it = result.begin(); it != result.end(); ++it)
(*it) = std::complex<T>((*it).real() / N, (*it).imag() / N);
return result;
}
}

namespace FFTInner {

template <typename T> std::vector<std::complex<T>> fft8(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");
static const std::vector<std::complex<T> > wn = {
std::complex<T>(1.0, 0.0),
std::complex<T>(cos(DOUBLE_PI / 8), -sin(DOUBLE_PI / 8)),
std::complex<T>(cos(DOUBLE_PI / 4), -sin(DOUBLE_PI / 4)),
std::complex<T>(cos(DOUBLE_PI * 3 / 8), -sin(DOUBLE_PI * 3 / 8)) };

std::vector<std::complex<T> > colTmp1, colTmp2, result;
colTmp1.reserve(8);
colTmp2.reserve(8);
result.reserve(8);

//DIT FFT
colTmp1.emplace_back(data[0] + data[4]);
colTmp1.emplace_back(data[0] - data[4]);
colTmp1.emplace_back(data[2] + data[6]);
colTmp1.emplace_back((data[2] - data[6]) * wn[2]);
colTmp1.emplace_back(data[1] + data[5]);
colTmp1.emplace_back(data[1] - data[5]);
colTmp1.emplace_back(data[3] + data[7]);
colTmp1.emplace_back((data[3] - data[7]) * wn[2]);

colTmp2.emplace_back(colTmp1[0] + colTmp1[2]);
colTmp2.emplace_back(colTmp1[1] + colTmp1[3]);
colTmp2.emplace_back(colTmp1[0] - colTmp1[2]);
colTmp2.emplace_back(colTmp1[1] - colTmp1[3]);
colTmp2.emplace_back(colTmp1[4] + colTmp1[6]);
colTmp2.emplace_back((colTmp1[5] + colTmp1[7]) * wn[1]);
colTmp2.emplace_back((colTmp1[4] - colTmp1[6]) * wn[2]);
colTmp2.emplace_back((colTmp1[5] - colTmp1[7]) * wn[3]);

result.emplace_back(colTmp2[0] + colTmp2[4]);
result.emplace_back(colTmp2[1] + colTmp2[5]);
result.emplace_back(colTmp2[2] + colTmp2[6]);
result.emplace_back(colTmp2[3] + colTmp2[7]);
result.emplace_back(colTmp2[0] - colTmp2[4]);
result.emplace_back(colTmp2[1] - colTmp2[5]);
result.emplace_back(colTmp2[2] - colTmp2[6]);
result.emplace_back(colTmp2[3] - colTmp2[7]);

return result;
}

template <typename T> std::vector<std::complex<T>> ifft8(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

static const std::vector<std::complex<T> > wn = {
std::complex<T>(1.0, 0.0),
std::complex<T>(cos(DOUBLE_PI / 8),sin(DOUBLE_PI / 8)),
std::complex<T>(cos(DOUBLE_PI / 4),sin(DOUBLE_PI / 4)),
std::complex<T>(cos(DOUBLE_PI * 3 / 8),sin(DOUBLE_PI * 3 / 8)) };

std::vector<std::complex<T> > colTmp1, colTmp2, result;
colTmp1.reserve(8);
colTmp2.reserve(8);
result.reserve(8);

//DIT IFFT
colTmp1.emplace_back(data[0] + data[4]);
colTmp1.emplace_back(data[0] - data[4]);
colTmp1.emplace_back(data[2] + data[6]);
colTmp1.emplace_back((data[2] - data[6]) * wn[2]);
colTmp1.emplace_back(data[1] + data[5]);
colTmp1.emplace_back(data[1] - data[5]);
colTmp1.emplace_back(data[3] + data[7]);
colTmp1.emplace_back((data[3] - data[7]) * wn[2]);

colTmp2.emplace_back(colTmp1[0] + colTmp1[2]);
colTmp2.emplace_back(colTmp1[1] + colTmp1[3]);
colTmp2.emplace_back(colTmp1[0] - colTmp1[2]);
colTmp2.emplace_back(colTmp1[1] - colTmp1[3]);
colTmp2.emplace_back(colTmp1[4] + colTmp1[6]);
colTmp2.emplace_back((colTmp1[5] + colTmp1[7]) * wn[1]);
colTmp2.emplace_back((colTmp1[4] - colTmp1[6]) * wn[2]);
colTmp2.emplace_back((colTmp1[5] - colTmp1[7]) * wn[3]);

result.emplace_back(colTmp2[0] + colTmp2[4]);
result.emplace_back(colTmp2[1] + colTmp2[5]);
result.emplace_back(colTmp2[2] + colTmp2[6]);
result.emplace_back(colTmp2[3] + colTmp2[7]);
result.emplace_back(colTmp2[0] - colTmp2[4]);
result.emplace_back(colTmp2[1] - colTmp2[5]);
result.emplace_back(colTmp2[2] - colTmp2[6]);
result.emplace_back(colTmp2[3] - colTmp2[7]);
return result;
}

template <typename T> std::vector<std::complex<T>> radix2FFT(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

auto n = data.size();
if (8 == n) {
return fft8(data);
}
if (1 == n) {
return data;
}
std::complex<T> wn(cos(DOUBLE_PI / n), -sin(DOUBLE_PI / n));
std::complex<T> w(1, 0);
typename std::vector<std::complex<T> > a0, a1;
a0.reserve(data.size() / 2);
a1.reserve(data.size() / 2);
for (auto it = data.cbegin(); it != data.cend();) {
a0.emplace_back((*it));
++it;
a1.emplace_back((*it));
++it;
}

auto y0 = radix2FFT(a0);
auto y1 = radix2FFT(a1);
std::vector<std::complex<T>> result(n, std::complex<T>(0., 0.));
for (size_t k = 0; k <= n / 2 - 1; k++) {
result[k] = y0[k] + w * y1[k];
result[k + n / 2] = y0[k] - w * y1[k];
w = w*wn;
}
return result;
}

template <typename T> std::vector<std::complex<T>> radix2IFFT(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

auto n = data.size();
if (8 == n) {
return ifft8(data);
}
if (1 == n) {
return data;
}
std::complex<T> wn(cos(-DOUBLE_PI / n), -sin(-DOUBLE_PI / n));
std::complex<T> w(1, 0);
typename std::vector<std::complex<T> > a0, a1;
a0.reserve(data.size() / 2);
a1.reserve(data.size() / 2);
for (auto it = data.cbegin(); it != data.cend();) {
a0.emplace_back((*it));
++it;
a1.emplace_back((*it));
++it;
}

auto y0 = radix2IFFT(a0);
auto y1 = radix2IFFT(a1);
std::vector<std::complex<T>> result(n, std::complex<T>(0., 0.));
for (size_t k = 0; k <= n / 2 - 1; k++) {
result[k] = y0[k] + w * y1[k];
result[k + n / 2] = y0[k] - w * y1[k];
w = w * wn;
}
return result;
}

template <typename T> std::vector<std::complex<T>> raderFFT(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

auto N = data.size();
auto X0 = std::accumulate(data.cbegin(), data.cend(), std::complex<T>(0.0, 0.0));
auto g = Number::findPrimitiveRoot(data.size());
std::vector<std::complex<T>> aq, bq, product;
std::vector<size_t> aqIndex, bqIndex;

aqIndex.emplace_back(1);
bqIndex.emplace_back(1);
aq.emplace_back(data[1]);
bq.emplace_back(std::complex<T>(cos(DOUBLE_PI/N), -sin(DOUBLE_PI/N)));
auto exp = Number::expModuloN(g, static_cast<size_t>(1), N);
auto expInverse = Number::expModuloNInverse(g, static_cast<size_t>(1), N);
auto expInverseBase = expInverse;

for (size_t index = 1; index <= N-2; index++){
aqIndex.emplace_back(exp);
bqIndex.emplace_back(expInverse);

aq.emplace_back(data[exp]);
auto tmp = expInverse*DOUBLE_PI/N;
bq.emplace_back(std::complex<T>(cos(tmp), -sin(tmp)));

exp = (exp * g) % N;
expInverse = (expInverse * expInverseBase) % N;
}

// 补零
auto M = FFTInner::nextPowerOf2(2*N-3);
if (M != N-1) {
aq.insert(aq.begin()+1, M-N+1, std::complex<T>(0.0, 0.0));
for (size_t index = 0; index < M-N+1; index++)
bq.emplace_back(bq[index]);
}

auto faq = radix2FFT(aq);
auto fbq = radix2FFT(bq);
for (size_t index = 0; index <= M-1; index++){
product.emplace_back(faq[index]*fbq[index]);
}
auto inverseDFT = radix2IFFT(product);
std::vector<std::complex<T>> result(N, std::complex<T>(0.0, 0.0));
result[0] = X0;

for (size_t index = 0; index < N-1; index++)
result[bqIndex[index]] = inverseDFT[index] / static_cast<T>(M) + data[0];

return result;
}

template <typename T> T nextPowerOf2(T n) {
if (n < 0)
return n;
unsigned int p = 1;
if (n && !(n & (n - 1)))
return n;

while (p < n) {
p <<= 1;
}
return p;
}

template <typename T> std::vector<std::complex<T>> raderIFFT(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

auto N = data.size();
auto X0 = std::accumulate(data.cbegin(), data.cend(), std::complex<T>(0.0, 0.0));
auto g = Number::findPrimitiveRoot(data.size());
std::vector<std::complex<T>> aq, bq, product;
std::vector<size_t> aqIndex, bqIndex;

aqIndex.emplace_back(1);
bqIndex.emplace_back(1);
aq.emplace_back(data[1]);
bq.emplace_back(std::complex<T>(cos(DOUBLE_PI/N), sin(DOUBLE_PI/N)));

auto exp = Number::expModuloN(g, static_cast<size_t>(1), N);
auto expInverse = Number::expModuloNInverse(g, static_cast<size_t>(1), N);
auto expInverseBase = expInverse;
for (size_t index = 1; index <= N-2; index++){
aqIndex.emplace_back(exp);
bqIndex.emplace_back(expInverse);

aq.emplace_back(data[exp]);
auto tmp = expInverse * DOUBLE_PI/N;
bq.emplace_back(std::complex<T>(cos(tmp), sin(tmp)));

exp = (exp * g) % N;
expInverse = (expInverse * expInverseBase) % N;
}

// 补零
auto M = FFTInner::nextPowerOf2(2*N-3);
if (M != N-1) {
aq.insert(aq.begin()+1, M-N+1, std::complex<T>(0.0, 0.0));
for (size_t index = 0; index < M-N+1; index++)
bq.emplace_back(bq[index]);
}

auto faq = radix2FFT(aq);
auto fbq = radix2FFT(bq);
for (size_t index = 0; index <= M-1; index++){
product.emplace_back(faq[index]*fbq[index]);
}
auto inverseDFT = radix2IFFT(product);
std::vector<std::complex<T>> result(N, std::complex<T>(0.0, 0.0));
result[0] = X0/static_cast<T>(N);

for (size_t index = 0; index < N-1; index++)
result[bqIndex[index]] = (inverseDFT[index] / static_cast<T>(M) + data[0])/static_cast<T>(N);

return result;
}

template <typename T> std::vector<std::complex<T>> hybridFFT(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

auto N = data.size();
if (N == 1 || N == 2) {
return radix2FFT(data);
}

// 如果N是质数
auto factors = Number::factor<size_t>(N);
if (factors.size() == 1) {
return raderFFT(data);
}

// 生成N1和N2,使得 N=N1*N2
size_t N1 = factors[0], N2 = N/N1;

std::complex<T> **X;
X = new std::complex<T>*[N1];
for (size_t i = 0; i < N1; i++)
X[i] = new std::complex<T>[N2];
for (size_t n1 = 0; n1 < N1; n1++)
for (size_t n2 = 0; n2 < N2; n2++)
X[n1][n2] = data[N1*n2+n1];
for (size_t n1=0;n1<N1;n1++) {
std::vector<std::complex<T>>row;
row.reserve(N2);
for (size_t i = 0;i < N2;i++)
row.emplace_back(X[n1][i]);
auto tmp = fft(row);
for (size_t n2 = 0; n2 < N2; n2++)
X[n1][n2] = tmp[n2] * std::complex<T>(cos(DOUBLE_PI*n1*n2/N),-sin(DOUBLE_PI*n1*n2/N));
}

for (size_t n2 = 0;n2 < N2;n2++) {
std::vector<std::complex<T>> col;
col.reserve(N1);
for (size_t n1=0;n1<N1;n1++)
col.emplace_back(X[n1][n2]);
auto tmp = fft(col);
for (size_t n1=0;n1<N1;n1++)
X[n1][n2] = tmp[n1];
}

std::vector<std::complex<T>> result(data.size(), std::complex<T>(0.0, 0.0));
for (size_t n1 = 0; n1 < N1; n1++)
for (size_t n2 = 0; n2 < N2; n2++)
result[N2*n1+n2] = X[n1][n2];

for (size_t i = 0; i < N1; i++)
delete [] X[i];
delete [] X;
return result;
}

template <typename T> std::vector<std::complex<T>> hybridIFFT(const std::vector<std::complex<T>> &data) {
static_assert(std::is_floating_point<T>::value, "T must float type!");

auto N = data.size();
if (N == 1 || N == 2) {
auto result = radix2IFFT(data);
for (auto item:result)
item = item / static_cast<T>(N);
return result;
}

// 如果N是质数
auto factors = Number::factor<size_t>(N);
if (factors.size() == 1) {
return raderIFFT(data);
}

// 生成N1和N2,使得 N=N1*N2
size_t N1 = factors[0], N2 = N/N1;

std::complex<T> **X;
X = new std::complex<T>*[N1];
for (size_t i = 0; i < N1; i++)
X[i] = new std::complex<T>[N2];
for (size_t n1 = 0; n1 < N1; n1++)
for (size_t n2 = 0; n2 < N2; n2++)
X[n1][n2] = data[N1*n2+n1];
for (size_t n1 = 0;n1 < N1;n1++) {
std::vector<std::complex<T>> row;
row.reserve(N2);
for (size_t i = 0;i < N2;i++)
row.emplace_back(X[n1][i]);
auto tmp = ifft(row);
for (size_t n2 = 0; n2 < N2; n2++)
X[n1][n2] = tmp[n2] * std::complex<T>(cos(DOUBLE_PI*n1*n2/N), sin(DOUBLE_PI*n1*n2/N))*static_cast<T>(N2);
}

for (size_t n2 = 0;n2 < N2;n2++) {
std::vector<std::complex<T>> col;
col.reserve(N1);
for (size_t n1 = 0;n1 < N1;n1++)
col.emplace_back(X[n1][n2]);
auto tmp = ifft(col);
for (size_t n1 = 0;n1 < N1;n1++)
X[n1][n2] = tmp[n1]*static_cast<T>(N1);
}

std::vector<std::complex<T>> result(data.size(), std::complex<T>(0.0, 0.0));
for (size_t n1 = 0; n1 < N1; n1++)
for (size_t n2 = 0; n2 < N2; n2++)
result[N2*n1+n2] = X[n1][n2]/static_cast<T>(N);

for (size_t i = 0; i < N1; i++)
delete [] X[i];
delete [] X;
return result;
}

template <typename T> std::vector<std::complex<T>> toComplex(const std::vector<T> &data) {
std::vector<std::complex<T>> result;
for (auto it = data.cbegin(); it != data.cend(); it++) {
result.emplace_back(std::complex<T>(*it, 0.0));
}
return result;
}

}
}

#endif

Number.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
//
// Number.h
// 该文件提供了数论相关函数
//

#ifndef Number_H
#define Number_H


#include <vector>
#include <deque>
#include <cmath>


//函数声明
namespace Number {

/**
* 找到质数 n 的一个原根,注意 matlab 中没有这个函数
* @tparam T 整数类型
* @param n 待求
* @return
*/
template <typename T> T findPrimitiveRoot(T n);

/**
* 判断一个整数 n 是不是素数,简单算法,TODO使用AKS素数测试
* @tparam T 整数类型
* @param n 待判断整数
*
* @return 是否为素数
*/
template <typename T> bool isprime(T n);

/**
* 对 n 进行因式分解
* @tparam T 整数类型
* @param n 待分解数
*
* @return 结果
*/
template <typename T> std::vector<T> factor(T n);

/**
* 判断一个整数 a 是不是模 m 的原根,m 应该为素数。注意 matlab 中没有这个函数
* @tparam T
* @param a 待判断整数
* @param m 模
*
* @return 是否为原根
*/
template <typename T> bool isPrimitiveRoot(T p, T m);

/**
* 模幂运算 (a^k)%n 的较快速算法,注意 matlab 中没有这个函数
* @tparam T 整数类型
* @param a
* @param k
* @param n
*
* @return (a^k)%n
*/
template <typename T> T expModuloN(T a, T k, T n);

/**
* 求模逆元的模幂运算(a^(-k))%n, 注意 matlab 中没有这个函数
* @tparam T
* @param a
* @param k
* @param n
*
* @return (a^(-k))%n
*/
template <typename T> T expModuloNInverse(T a, T k, T n);
}


namespace Number {

template <typename T> T findPrimitiveRoot(T n) {
static_assert(std::is_unsigned<T>::value, "T must unsigned type!");

if(!isprime(n))
throw std::invalid_argument("n should be prime!");
for (T primeRootCandidate = 2; primeRootCandidate <= n-1; primeRootCandidate++) {
if (isPrimitiveRoot(primeRootCandidate, n)) return primeRootCandidate;
}
throw std::runtime_error("Prime root not found");
}

template <typename T> bool isprime(T n) {
static_assert(std::is_unsigned<T>::value, "T must unsigned type!");

if (n <= 1) throw std::invalid_argument("Prime/Composite test should bigger than 1");
if (n == 2) return true;
if (n % 2 == 0) return false;
T end = static_cast<T>(sqrt(n));
for (T start = 3; start <= end; start+=2 ){
if (n % start == 0) return false;
}
return true;
}

template <typename T> bool isPrimitiveRoot(T p, T m) {
static_assert(std::is_unsigned<T>::value, "T must unsigned type!");

T tot = m - 1;
auto factors = factor(tot);

//TODO faster possible
for (T pi : factors) {
if (expModuloN(p, tot/pi, m) == 1) return false;
}
return true;
}

template <typename T> std::vector<T> factor(T n) {
static_assert(std::is_unsigned<T>::value, "T must unsigned type!");

std::vector<T> result;
if (1 == n) {
result.emplace_back(1);
return result;
}
for (T i = 2; i <= n; i++) {
while (n != i) {
if (n % i == 0) {
result.emplace_back(i);
n = n / i;
} else
break;
}
}
result.emplace_back(n);
return result;
}

template <typename T> T expModuloN(T a, T k, T n) {
static_assert(std::is_unsigned<T>::value, "T must unsigned type!");

if (k == 0) return 1;
if (k == 1) return a % n;
if (k == 2) return (a*a) % n;
T k1 = k / 2;
T k2 = k - k1;
if (k1 < k2) {
T tmp1 = expModuloN(a, k1, n);
T tmp2 = (tmp1*a)%n;
return (tmp1*tmp2) % n;
} else {
T tmp = expModuloN(a, k1, n);
return (tmp*tmp) % n;
}
}

template <typename T> T expModuloNInverse(T a, T k, T n) {
static_assert(std::is_unsigned<T>::value, "T must unsigned type!");

if (k == 0) return 1;
if (k == 1) {
for (T inverse = 0; inverse <= n-1; inverse++) {
if ((a*inverse) % n == 1) return inverse;
}
throw std::runtime_error("modular inverse not found!");
}
T modInverse = expModuloNInverse<T>(a, 1, n);
return expModuloN<T>(modInverse, k, n);
}
}

#endif

参考资料

sky blue trades

wiki

数字信号处理——原理,算法与应用