From 09b47332553a79dab30516e6b1d410dea90cf9b7 Mon Sep 17 00:00:00 2001 From: Mark Borgerding Date: Mon, 25 May 2009 23:52:21 -0400 Subject: [PATCH] added real-optimized inverse FFT (NFFT must be multiple of 4) --- bench/benchFFT.cpp | 30 +- unsupported/Eigen/src/FFT/ei_kissfft_impl.h | 688 ++++++++++---------- 2 files changed, 370 insertions(+), 348 deletions(-) diff --git a/bench/benchFFT.cpp b/bench/benchFFT.cpp index ffa4ffffc..14f5063fb 100644 --- a/bench/benchFFT.cpp +++ b/bench/benchFFT.cpp @@ -53,7 +53,7 @@ template <> string nameof() {return "long double";} using namespace Eigen; template -void bench(int nfft) +void bench(int nfft,bool fwd) { typedef typename NumTraits::Real Scalar; typedef typename std::complex Complex; @@ -69,7 +69,10 @@ void bench(int nfft) for (int k=0;k<8;++k) { timer.start(); for(int i = 0; i < nits; i++) - fft.fwd( outbuf , inbuf); + if (fwd) + fft.fwd( outbuf , inbuf); + else + fft.inv(inbuf,outbuf); timer.stop(); } @@ -82,16 +85,27 @@ void bench(int nfft) mflops /= 2; } + if (fwd) + cout << " fwd"; + else + cout << " inv"; + cout << " NFFT=" << nfft << " " << (double(1e-6*nfft*nits)/timer.value()) << " MS/s " << mflops << "MFLOPS\n"; } int main(int argc,char ** argv) { - bench >(NFFT); - bench(NFFT); - bench >(NFFT); - bench(NFFT); - bench >(NFFT); - bench(NFFT); + bench >(NFFT,true); + bench >(NFFT,false); + bench(NFFT,true); + bench(NFFT,false); + bench >(NFFT,true); + bench >(NFFT,false); + bench(NFFT,true); + bench(NFFT,false); + bench >(NFFT,true); + bench >(NFFT,false); + bench(NFFT,true); + bench(NFFT,false); return 0; } diff --git a/unsupported/Eigen/src/FFT/ei_kissfft_impl.h b/unsupported/Eigen/src/FFT/ei_kissfft_impl.h index 3580e6c61..453c7f6da 100644 --- a/unsupported/Eigen/src/FFT/ei_kissfft_impl.h +++ b/unsupported/Eigen/src/FFT/ei_kissfft_impl.h @@ -28,390 +28,398 @@ namespace Eigen { - template - struct ei_kiss_cpx_fft + template + struct ei_kiss_cpx_fft + { + typedef _Scalar Scalar; + typedef std::complex Complex; + std::vector m_twiddles; + std::vector m_stageRadix; + std::vector m_stageRemainder; + bool m_inverse; + + void make_twiddles(int nfft,bool inverse) { - typedef _Scalar Scalar; - typedef std::complex Complex; - std::vector m_twiddles; - std::vector m_stageRadix; - std::vector m_stageRemainder; - bool m_inverse; + m_inverse = inverse; + m_twiddles.resize(nfft); + Scalar phinc = (inverse?2:-2)* acos( (Scalar) -1) / nfft; + for (int i=0;in) + p=n;// impossible to have a factor > sqrt(n) + } + n /= p; + m_stageRadix.push_back(p); + m_stageRemainder.push_back(n); + }while(n>1); + } + + template + void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride) + { + int p = m_stageRadix[stage]; + int m = m_stageRemainder[stage]; + Complex * Fout_beg = xout; + Complex * Fout_end = xout + p*m; + + if (m>1) { + do{ + // recursive call: + // DFT of size m*p performed by doing + // p instances of smaller DFTs of size m, + // each one takes a decimated version of the input + work(stage+1, xout , xin, fstride*p,in_stride); + xin += fstride*in_stride; + }while( (xout += m) != Fout_end ); + }else{ + do{ + *xout = *xin; + xin += fstride*in_stride; + }while(++xout != Fout_end ); + } + xout=Fout_beg; + + // recombine the p smaller DFTs + switch (p) { + case 2: bfly2(xout,fstride,m); break; + case 3: bfly3(xout,fstride,m); break; + case 4: bfly4(xout,fstride,m); break; + case 5: bfly5(xout,fstride,m); break; + default: bfly_generic(xout,fstride,m,p); break; + } + } + + void bfly2( Complex * Fout, const size_t fstride, int m) + { + for (int k=0;kreal() - .5*scratch[3].real() , Fout->imag() - .5*scratch[3].imag() ); + scratch[0] *= epi3.imag(); + *Fout += scratch[3]; + Fout[m2] = Complex( Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() ); + Fout[m] += Complex( -scratch[0].imag(),scratch[0].real() ); + ++Fout; + }while(--k); + } + + void bfly5( Complex * Fout, const size_t fstride, const size_t m) + { + Complex *Fout0,*Fout1,*Fout2,*Fout3,*Fout4; + size_t u; + Complex scratch[13]; + Complex * twiddles = &m_twiddles[0]; + Complex *tw; + Complex ya,yb; + ya = twiddles[fstride*m]; + yb = twiddles[fstride*2*m]; + + Fout0=Fout; + Fout1=Fout0+m; + Fout2=Fout0+2*m; + Fout3=Fout0+3*m; + Fout4=Fout0+4*m; + + tw=twiddles; + for ( u=0; u=Norig) twidx-=Norig; + t=scratchbuf[q] * twiddles[twidx]; + Fout[ k ] += t; + } + k += m; } - - void factorize(int nfft) - { - if (m_stageRadix.size()==0 || m_stageRadix[0] * m_stageRemainder[0] != nfft) - { - m_stageRadix.resize(0); - m_stageRemainder.resize(0); - //factorize - //start factoring out 4's, then 2's, then 3,5,7,9,... - int n= nfft; - int p=4; - do { - while (n % p) { - switch (p) { - case 4: p = 2; break; - case 2: p = 3; break; - default: p += 2; break; - } - if (p*p>n) - p=n;// impossible to have a factor > sqrt(n) - } - n /= p; - m_stageRadix.push_back(p); - m_stageRemainder.push_back(n); - }while(n>1); - } - } - - template - void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride) - { - int p = m_stageRadix[stage]; - int m = m_stageRemainder[stage]; - Complex * Fout_beg = xout; - Complex * Fout_end = xout + p*m; - - if (m>1) { - do{ - // recursive call: - // DFT of size m*p performed by doing - // p instances of smaller DFTs of size m, - // each one takes a decimated version of the input - work(stage+1, xout , xin, fstride*p,in_stride); - xin += fstride*in_stride; - }while( (xout += m) != Fout_end ); - }else{ - do{ - *xout = *xin; - xin += fstride*in_stride; - }while(++xout != Fout_end ); - } - xout=Fout_beg; - - // recombine the p smaller DFTs - switch (p) { - case 2: bfly2(xout,fstride,m); break; - case 3: bfly3(xout,fstride,m); break; - case 4: bfly4(xout,fstride,m); break; - case 5: bfly5(xout,fstride,m); break; - default: bfly_generic(xout,fstride,m,p); break; - } - } - - void bfly2( Complex * Fout, const size_t fstride, int m) - { - for (int k=0;kreal() - .5*scratch[3].real() , Fout->imag() - .5*scratch[3].imag() ); - scratch[0] *= epi3.imag(); - *Fout += scratch[3]; - Fout[m2] = Complex( Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() ); - Fout[m] += Complex( -scratch[0].imag(),scratch[0].real() ); - ++Fout; - }while(--k); - } - - void bfly5( Complex * Fout, const size_t fstride, const size_t m) - { - Complex *Fout0,*Fout1,*Fout2,*Fout3,*Fout4; - size_t u; - Complex scratch[13]; - Complex * twiddles = &m_twiddles[0]; - Complex *tw; - Complex ya,yb; - ya = twiddles[fstride*m]; - yb = twiddles[fstride*2*m]; - - Fout0=Fout; - Fout1=Fout0+m; - Fout2=Fout0+2*m; - Fout3=Fout0+3*m; - Fout4=Fout0+4*m; - - tw=twiddles; - for ( u=0; u=Norig) twidx-=Norig; - t=scratchbuf[q] * twiddles[twidx]; - Fout[ k ] += t; - } - k += m; - } - } - } - }; - + } + } + }; template struct ei_kissfft_impl { - typedef _Scalar Scalar; - typedef std::complex Complex; - ei_kissfft_impl() {} + typedef _Scalar Scalar; + typedef std::complex Complex; - void clear() - { + void clear() + { m_plans.clear(); m_realTwiddles.clear(); - } + } - template - void fwd( Complex * dst,const _Src *src,int nfft) - { + template + void fwd( Complex * dst,const _Src *src,int nfft) + { get_plan(nfft,false).work(0, dst, src, 1,1); - } + } - // real-to-complex forward FFT - // perform two FFTs of src even and src odd - // then twiddle to recombine them into the half-spectrum format - // then fill in the conjugate symmetric half - void fwd( Complex * dst,const Scalar * src,int nfft) - { + // real-to-complex forward FFT + // perform two FFTs of src even and src odd + // then twiddle to recombine them into the half-spectrum format + // then fill in the conjugate symmetric half + void fwd( Complex * dst,const Scalar * src,int nfft) + { if ( nfft&3 ) { - // use generic mode for odd - get_plan(nfft,false).work(0, dst, src, 1,1); + // use generic mode for odd + get_plan(nfft,false).work(0, dst, src, 1,1); }else{ - int ncfft = nfft>>1; - int ncfft2 = nfft>>2; - Complex * rtw = real_twiddles(ncfft2); + int ncfft = nfft>>1; + int ncfft2 = nfft>>2; + Complex * rtw = real_twiddles(ncfft2); - // use optimized mode for even real - fwd( dst, reinterpret_cast (src), ncfft); - Complex dc = dst[0].real() + dst[0].imag(); - Complex nyquist = dst[0].real() - dst[0].imag(); - int k; - for ( k=1;k <= ncfft2 ; ++k ) { - Complex fpk = dst[k]; - Complex fpnk = conj(dst[ncfft-k]); - Complex f1k = fpk + fpnk; - Complex f2k = fpk - fpnk; - Complex tw= f2k * rtw[k-1]; + // use optimized mode for even real + fwd( dst, reinterpret_cast (src), ncfft); + Complex dc = dst[0].real() + dst[0].imag(); + Complex nyquist = dst[0].real() - dst[0].imag(); + int k; + for ( k=1;k <= ncfft2 ; ++k ) { + Complex fpk = dst[k]; + Complex fpnk = conj(dst[ncfft-k]); + Complex f1k = fpk + fpnk; + Complex f2k = fpk - fpnk; + Complex tw= f2k * rtw[k-1]; + dst[k] = (f1k + tw) * Scalar(.5); + dst[ncfft-k] = conj(f1k -tw)*Scalar(.5); + } - dst[k] = (f1k + tw) * Scalar(.5); - dst[ncfft-k] = conj(f1k -tw)*Scalar(.5); - } - - // place conjugate-symmetric half at the end for completeness - // TODO: make this configurable ( opt-out ) - for ( k=1;k < ncfft ; ++k ) - dst[nfft-k] = conj(dst[k]); - - dst[0] = dc; - dst[ncfft] = nyquist; + // place conjugate-symmetric half at the end for completeness + // TODO: make this configurable ( opt-out ) + for ( k=1;k < ncfft ; ++k ) + dst[nfft-k] = conj(dst[k]); + dst[0] = dc; + dst[ncfft] = nyquist; } - } + } - // half-complex to scalar - void inv( Scalar * dst,const Complex * src,int nfft) - { - // TODO add optimized version for even numbers - std::vector tmp(nfft); - inv(&tmp[0],src,nfft); - for (int k=0;k>1; + int ncfft2 = nfft>>2; + Complex * rtw = real_twiddles(ncfft2); + m_scratchBuf.resize(ncfft); + m_scratchBuf[0] = Complex( src[0].real() + src[ncfft].real(), src[0].real() - src[ncfft].real() ); + for (int k = 1; k <= ncfft / 2; ++k) { + Complex fk = src[k]; + Complex fnkc = conj(src[ncfft-k]); + Complex fek = fk + fnkc; + Complex tmp = fk - fnkc; + Complex fok = tmp * conj(rtw[k-1]); + m_scratchBuf[k] = fek + fok; + m_scratchBuf[ncfft-k] = conj(fek - fok); + } + scale(&m_scratchBuf[0], ncfft, Scalar(1)/nfft ); + get_plan(ncfft,true).work(0, reinterpret_cast(dst), &m_scratchBuf[0], 1,1); + } + } - typedef ei_kiss_cpx_fft PlanData; + private: - typedef std::map PlanMap; - PlanMap m_plans; - std::map > m_realTwiddles; + typedef ei_kiss_cpx_fft PlanData; - int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; } + typedef std::map PlanMap; + PlanMap m_plans; + std::map > m_realTwiddles; + std::vector m_scratchBuf; - PlanData & get_plan(int nfft,bool inverse) - { - /* + int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; } + + PlanData & get_plan(int nfft,bool inverse) + { + /* TODO: figure out why this does not work (g++ 4.3.2) * for some reason this does not work * - typedef typename std::map::iterator MapIt; - MapIt it; - it = m_plans.find( PlanKey(nfft,inverse) ); - if (it == m_plans.end() ) { - // create new entry - it = m_plans.insert( make_pair( PlanKey(nfft,inverse) , PlanData() ) ); - MapIt it2 = m_plans.find( PlanKey(nfft,!inverse) ); - if (it2 != m_plans.end() ) { - it->second = it2.second; - it->second.invert(); - }else{ - it->second.make_twiddles(nfft,inverse); - it->second.factorize(nfft); - } + PlanMap::iterator it; + it = m_plans.find( PlanKey(nfft,inverse) ); + if (it == m_plans.end() ) { + // create new entry + it = m_plans.insert( make_pair( PlanKey(nfft,inverse) , PlanData() ) ); + MapIt it2 = m_plans.find( PlanKey(nfft,!inverse) ); + if (it2 != m_plans.end() ) { + it->second = it2.second; + it->second.conjugate(); + }else{ + it->second.make_twiddles(nfft,inverse); + it->second.factorize(nfft); + } } return it->second; */ PlanData & pd = m_plans[ PlanKey(nfft,inverse) ]; if ( pd.m_twiddles.size() == 0 ) { - pd.make_twiddles(nfft,inverse); - pd.factorize(nfft); + pd.make_twiddles(nfft,inverse); + pd.factorize(nfft); } return pd; - } + } - Complex * real_twiddles(int ncfft2) - { + Complex * real_twiddles(int ncfft2) + { std::vector & twidref = m_realTwiddles[ncfft2];// creates new if not there if ( (int)twidref.size() != ncfft2 ) { - twidref.resize(ncfft2); - int ncfft= ncfft2<<1; - Scalar pi = acos( Scalar(-1) ); - for (int k=1;k<=ncfft2;++k) - twidref[k-1] = exp( Complex(0,-pi * ((double) (k) / ncfft + .5) ) ); + twidref.resize(ncfft2); + int ncfft= ncfft2<<1; + Scalar pi = acos( Scalar(-1) ); + for (int k=1;k<=ncfft2;++k) + twidref[k-1] = exp( Complex(0,-pi * ((double) (k) / ncfft + .5) ) ); } return &twidref[0]; - } + } - void scale(Complex *dst,int n,Scalar s) - { + void scale(Complex *dst,int n,Scalar s) + { for (int k=0;k