LCOV - code coverage report
Current view: top level - Adaptive - BlockLMSFilter.cpp (source / functions) Hit Total Coverage
Test: coverage.info.cleaned Lines: 93 96 96.9 %
Date: 2021-02-18 20:07:22 Functions: 18 19 94.7 %

          Line data    Source code
       1             : /**
       2             :  * \file BlockLMSFilter.cpp
       3             :  */
       4             : 
       5             : #include "BlockLMSFilter.h"
       6             : #include <ATK/Core/TypeTraits.h>
       7             : #include <ATK/Core/Utilities.h>
       8             : #include <ATK/Utility/FFT.h>
       9             : 
      10             : #include <Eigen/Core>
      11             : 
      12             : #include <complex>
      13             : #include <cstdint>
      14             : #include <stdexcept>
      15             : 
      16             : namespace ATK
      17             : {
      18             :   template<typename DataType_>
      19             :   class BlockLMSFilter<DataType_>::BlockLMSFilterImpl
      20             :   {
      21             :   public:
      22             :     using cwType = Eigen::Matrix<std::complex<double>, Eigen::Dynamic, 1>;
      23             :     using cxType = Eigen::Map<const cwType>;
      24             :     using wType = Eigen::Matrix<DataType_, Eigen::Dynamic, 1>;
      25             :     using xType = Eigen::Map<const wType>;
      26             : 
      27             :     /// FFT of the current coefficients
      28             :     cwType wfft;
      29             :     /// Current accumulated input
      30             :     std::vector<DataType_> block_input;
      31             :     /// Current accumulated ref
      32             :     std::vector<DataType_> block_ref;
      33             :     /// Current accumulated error
      34             :     std::vector<DataType_> block_error;
      35             : 
      36             :     /// Temporary storage
      37             :     std::vector<std::complex<double> > block_fft;
      38             :     /// Temporary storage
      39             :     std::vector<std::complex<double> > block_fft2;
      40             :     /// Temporary storage
      41             :     std::vector<DataType_> block_ifft;
      42             : 
      43             :     FFT<double> fft;
      44             :     /// Memory factor
      45             :     double alpha = .99;
      46             :     /// line search
      47             :     double mu = 0.05;
      48             :     /// block size
      49             :     gsl::index block_size = 0;
      50             :     gsl::index accumulate_block_size = 0;
      51             :     bool learning = true;
      52             : 
      53          15 :     explicit BlockLMSFilterImpl(gsl::index size)
      54          15 :     :wfft(cwType::Zero(2*size)), block_input(2 * size, DataType_(0)), block_ref(size, DataType_(0)), block_error(size, DataType_(0)),
      55          15 :      block_fft(2 * size), block_fft2(2 * size), block_ifft(2 * size), block_size(size)
      56             :     {
      57          15 :       fft.set_size(2 * size);
      58          15 :     }
      59             : 
      60        1300 :     void apply_update()
      61             :     {
      62        1300 :       ++accumulate_block_size;
      63        1300 :       if (accumulate_block_size == block_size)
      64             :       {
      65          13 :         fft.process_forward(block_input.data(), block_fft2.data(), block_size * 2);
      66        2613 :         for(gsl::index i = 0; i < 2 * block_size; ++i)
      67             :         {
      68        2600 :           block_fft[i] = block_fft2[i] * wfft(i, 0) * std::complex<double>(block_size * 2); // Diagonal U * FFT factor
      69             :         }
      70          13 :         fft.process_backward(block_fft.data(), block_ifft.data(), block_size * 2);
      71        1313 :         for (gsl::index i = 0; i < block_size; ++i)
      72             :         {
      73        1300 :           block_ifft[block_size + i] = block_ref[i] - block_ifft[block_size + i]; // error on last elements of Y
      74        1300 :           block_error[i] = block_ifft[block_size + i];
      75             :         }
      76          13 :         if (learning)
      77             :         {
      78          13 :           std::fill(block_ifft.begin(), block_ifft.begin() + block_size, 0);
      79          13 :           fft.process_forward(block_ifft.data(), block_fft.data(), block_size * 2); // FFT of the error stored in ifft
      80        2613 :           for (gsl::index i = 0; i < 2 * block_size; ++i)
      81             :           {
      82        2600 :             block_fft[i] = std::conj(block_fft2[i]) * block_fft[i] * std::complex<double>(block_size * 2); // diagonal * FFT factor
      83             :           }
      84          13 :           fft.process_backward(block_fft.data(), block_ifft.data(), 2 * block_size);
      85          13 :           fft.process_forward(block_ifft.data(), block_fft.data(), block_size);
      86          13 :           wfft = alpha * wfft + static_cast<std::complex<double>>(mu) * cxType(block_fft.data(), 2 * block_size);
      87             :         }
      88             : 
      89          13 :         accumulate_block_size = 0;
      90          13 :         std::memcpy(&block_input[0], &block_input[block_size], block_size * sizeof(DataType_));
      91             :       }
      92        1300 :     }
      93             : 
      94        1300 :     void update(DataType input, DataType ref, DataType& error)
      95             :     {
      96        1300 :       block_input[block_size + accumulate_block_size] = input;
      97        1300 :       error = block_ref[accumulate_block_size] - block_error[accumulate_block_size];
      98        1300 :       block_ref[accumulate_block_size] = ref;
      99             : 
     100        1300 :       apply_update();
     101        1300 :     }
     102             :   };
     103             : 
     104             :   template<typename DataType_>
     105          14 :   BlockLMSFilter<DataType_>::BlockLMSFilter(gsl::index size)
     106          16 :   :Parent(2, 1), impl(std::make_unique<BlockLMSFilterImpl>(size))
     107             :   {
     108          14 :     if (size == 0)
     109             :     {
     110           1 :       throw RuntimeError("Size must be strictly positive");
     111             :     }
     112          13 :   }
     113             :   
     114             :   template<typename DataType_>
     115          14 :   BlockLMSFilter<DataType_>::~BlockLMSFilter()
     116             :   {
     117          14 :   }
     118             :   
     119             :   template<typename DataType_>
     120           2 :   void BlockLMSFilter<DataType_>::set_size(gsl::index size)
     121             :   {
     122           2 :     if(size == 0)
     123             :     {
     124           1 :       throw RuntimeError("Size must be strictly positive");
     125             :     }
     126           1 :     auto block_size = impl->block_size;
     127           1 :     impl = std::make_unique<BlockLMSFilterImpl>(size);
     128           1 :     set_block_size(block_size);
     129           1 :   }
     130             : 
     131             :   template<typename DataType_>
     132           1 :   gsl::index BlockLMSFilter<DataType_>::get_size() const
     133             :   {
     134           1 :     return impl->wfft.size() / 2;
     135             :   }
     136             :   
     137             :   template<typename DataType_>
     138           3 :   void BlockLMSFilter<DataType_>::set_block_size(gsl::index size)
     139             :   {
     140           3 :     if (size == 0)
     141             :     {
     142           1 :       throw ATK::RuntimeError("Block size must be strictly positive");
     143             :     }
     144           2 :     impl->accumulate_block_size = 0;
     145           2 :     impl->block_size = size;
     146           2 :     impl->block_input.assign(2 * size, 0);
     147           2 :     impl->block_ref.assign(size, 0);
     148           2 :     impl->block_fft.assign(2 * size, 0);
     149           2 :     impl->block_fft2.assign(2 * size, 0);
     150           2 :     impl->block_ifft.assign(2 * size, 0);
     151           2 :   }
     152             : 
     153             :   template<typename DataType_>
     154           1 :   gsl::index BlockLMSFilter<DataType_>::get_block_size() const
     155             :   {
     156           1 :     return impl->block_size;
     157             :   }
     158             : 
     159             :   template<typename DataType_>
     160           4 :   void BlockLMSFilter<DataType_>::set_memory(double memory)
     161             :   {
     162           4 :     if (memory >= 1)
     163             :     {
     164           1 :       throw ATK::RuntimeError("Memory must be less than 1");
     165             :     }
     166           3 :     if (memory <= 0)
     167             :     {
     168           1 :       throw ATK::RuntimeError("Memory must be strictly positive");
     169             :     }
     170             : 
     171           2 :     impl->alpha = memory;
     172           2 :   }
     173             : 
     174             :   template<typename DataType_>
     175           1 :   double BlockLMSFilter<DataType_>::get_memory() const
     176             :   {
     177           1 :     return impl->alpha;
     178             :   }
     179             : 
     180             :   template<typename DataType_>
     181           4 :   void BlockLMSFilter<DataType_>::set_mu(double mu)
     182             :   {
     183           4 :     if (mu >= 1)
     184             :     {
     185           1 :       throw ATK::RuntimeError("Mu must be less than 1");
     186             :     }
     187           3 :     if (mu <= 0)
     188             :     {
     189           1 :       throw ATK::RuntimeError("Mu must be strictly positive");
     190             :     }
     191             : 
     192           2 :     impl->mu = mu;
     193           2 :   }
     194             : 
     195             :   template<typename DataType_>
     196           1 :   double BlockLMSFilter<DataType_>::get_mu() const
     197             :   {
     198           1 :     return impl->mu;
     199             :   }
     200             : 
     201             :   template<typename DataType_>
     202           2 :   void BlockLMSFilter<DataType_>::process_impl(gsl::index size) const
     203             :   {
     204           2 :     const DataType* ATK_RESTRICT input = converted_inputs[0];
     205           2 :     const DataType* ATK_RESTRICT ref = converted_inputs[1];
     206           2 :     DataType* ATK_RESTRICT output = outputs[0];
     207             :     
     208        1302 :     for(gsl::index i = 0; i < size; ++i)
     209             :     {
     210        1300 :       impl->update(input[i], ref[i], output[i]);
     211             :     }
     212           2 :   }
     213             : 
     214             :   template<typename DataType_>
     215           1 :   const std::complex<double>* BlockLMSFilter<DataType_>::get_w() const
     216             :   {
     217           1 :     return impl->wfft.data();
     218             :   }
     219             :   
     220             :   template<typename DataType_>
     221           0 :   void BlockLMSFilter<DataType_>::set_w(gsl::not_null<const std::complex<double>*> w)
     222             :   {
     223           0 :     impl->wfft = Eigen::Map<const typename BlockLMSFilterImpl::cwType>(w.get(), get_size() * 2);
     224           0 :   }
     225             : 
     226             :   template<typename DataType_>
     227           1 :   void BlockLMSFilter<DataType_>::set_learning(bool learning)
     228             :   {
     229           1 :     impl->learning = learning;
     230           1 :   }
     231             : 
     232             :   template<typename DataType_>
     233           2 :   bool BlockLMSFilter<DataType_>::get_learning() const
     234             :   {
     235           2 :     return impl->learning;
     236             :   }
     237             : 
     238             :   template class BlockLMSFilter<double>;
     239             : #if ATK_ENABLE_INSTANTIATION
     240             :   template class BlockLMSFilter<std::complex<double>>;
     241             : #endif
     242             : }

Generated by: LCOV version TK-3.3.0-4-gdba42eea