LCOV - code coverage report
Current view: top level - Adaptive - LMSFilter.cpp (source / functions) Hit Total Coverage
Test: coverage.info.cleaned Lines: 65 90 72.2 %
Date: 2021-02-18 20:07:22 Functions: 18 23 78.3 %

          Line data    Source code
       1             : /**
       2             :  * \file LMSFilter.cpp
       3             :  */
       4             : 
       5             : #include "LMSFilter.h"
       6             : #include <ATK/Core/TypeTraits.h>
       7             : #include <ATK/Core/Utilities.h>
       8             : 
       9             : #include <Eigen/Core>
      10             : 
      11             : #include <complex>
      12             : #include <cstdint>
      13             : #include <stdexcept>
      14             : 
      15             : namespace ATK
      16             : {
      17             :   template<typename DataType_>
      18             :   class LMSFilter<DataType_>::LMSFilterImpl
      19             :   {
      20             :   public:
      21             :     using wType = Eigen::Matrix<DataType_, Eigen::Dynamic, 1>;
      22             :     using xType = Eigen::Map<const wType>;
      23             : 
      24             :     wType w;
      25             :     /// Memory factor
      26             :     double alpha = 0.99;
      27             :     /// line search
      28             :     double mu = 0.05;
      29             : 
      30          13 :     explicit LMSFilterImpl(gsl::index size)
      31          13 :     :w(wType::Zero(size))
      32             :     {
      33          13 :     }
      34             : 
      35             :     using UpdateFunction = void (LMSFilterImpl::*)(const xType& x, DataType error);
      36             : 
      37        1200 :     void update(const xType& x, DataType error)
      38             :     {
      39        1200 :       w = static_cast<DataType>(alpha) * w + static_cast<DataType>(mu) * error * x;
      40        1200 :     }
      41             : 
      42           0 :     void update_normalized(const xType& x, DataType error)
      43             :     {
      44           0 :       w = static_cast<DataType>(alpha) * w + static_cast<DataType>(mu) * error * x / (std::numeric_limits<DataType>::epsilon() + static_cast<DataType>(x.squaredNorm()));
      45           0 :     }
      46             : 
      47           0 :     void update_signerror(const xType& x, DataType error)
      48             :     {
      49           0 :       w = static_cast<DataType>(alpha) * w + static_cast<DataType>(mu) * error / (std::numeric_limits<DataType>::epsilon() + std::abs(error)) * x;
      50           0 :     }
      51             : 
      52           0 :     void update_signdata(const xType& x, DataType error)
      53             :     {
      54           0 :       w = static_cast<DataType>(alpha) * w.array() + static_cast<DataType>(mu) * error * x.array() / (x.cwiseAbs().template cast<DataType>().array() + static_cast<DataType>(std::numeric_limits<DataType>::epsilon()));
      55           0 :     }
      56             : 
      57           0 :     void update_signsign(const xType& x, DataType error)
      58             :     {
      59           0 :       w = static_cast<DataType>(alpha) * w.array() + static_cast<DataType>(mu) * error / (std::numeric_limits<DataType>::epsilon() + std::abs(error)) * x.array() / (x.cwiseAbs().template cast<DataType>().array() + static_cast<DataType>(std::numeric_limits<DataType>::epsilon()));
      60           0 :     }
      61             : 
      62           1 :     UpdateFunction select(Mode mode)
      63             :     {
      64           1 :       switch (mode)
      65             :       {
      66           1 :       case Mode::NORMAL:
      67           1 :         return &LMSFilterImpl::update;
      68           0 :       case Mode::NORMALIZED:
      69           0 :         return &LMSFilterImpl::update_normalized;
      70           0 :       case Mode::SIGNERROR:
      71           0 :         return &LMSFilterImpl::update_signerror;
      72           0 :       case Mode::SIGNDATA:
      73           0 :         return &LMSFilterImpl::update_signdata;
      74           0 :       case Mode::SIGNSIGN:
      75           0 :         return &LMSFilterImpl::update_signsign;
      76           0 :       default:
      77           0 :           throw std::range_error("Wrong mode for LMS filter");
      78             :       }
      79             :     }
      80             :   };
      81             : 
      82             :   template<typename DataType_>
      83          12 :   LMSFilter<DataType_>::LMSFilter(gsl::index size)
      84          12 :   :Parent(2, 1), impl(std::make_unique<LMSFilterImpl>(size))
      85             :   {
      86          12 :     input_delay = size - 1;
      87          12 :   }
      88             :   
      89             :   template<typename DataType_>
      90          13 :   LMSFilter<DataType_>::~LMSFilter()
      91             :   {
      92          13 :   }
      93             : 
      94             :   template<typename DataType_>
      95           2 :   void LMSFilter<DataType_>::set_size(gsl::index size)
      96             :   {
      97           2 :     if(size == 0)
      98             :     {
      99           1 :       throw RuntimeError("Size must be strictly positive");
     100             :     }
     101             : 
     102           1 :     input_delay = size - 1;
     103           1 :     impl = std::make_unique<LMSFilterImpl>(size);
     104           1 :   }
     105             : 
     106             :   template<typename DataType_>
     107           1 :   gsl::index LMSFilter<DataType_>::get_size() const
     108             :   {
     109           1 :     return input_delay + 1;
     110             :   }
     111             :   
     112             :   template<typename DataType_>
     113           4 :   void LMSFilter<DataType_>::set_memory(double memory)
     114             :   {
     115           4 :     if (memory >= 1)
     116             :     {
     117           1 :       throw ATK::RuntimeError("Memory must be less than 1");
     118             :     }
     119           3 :     if (memory <= 0)
     120             :     {
     121           1 :       throw ATK::RuntimeError("Memory must be strictly positive");
     122             :     }
     123             : 
     124           2 :     impl->alpha = memory;
     125           2 :   }
     126             : 
     127             :   template<typename DataType_>
     128           1 :   double LMSFilter<DataType_>::get_memory() const
     129             :   {
     130           1 :     return impl->alpha;
     131             :   }
     132             : 
     133             :   template<typename DataType_>
     134           4 :   void LMSFilter<DataType_>::set_mu(double mu)
     135             :   {
     136           4 :     if (mu >= 1)
     137             :     {
     138           1 :       throw ATK::RuntimeError("Mu must be less than 1");
     139             :     }
     140           3 :     if (mu <= 0)
     141             :     {
     142           1 :       throw ATK::RuntimeError("Mu must be strictly positive");
     143             :     }
     144             : 
     145           2 :     impl->mu = mu;
     146           2 :   }
     147             : 
     148             :   template<typename DataType_>
     149           1 :   double LMSFilter<DataType_>::get_mu() const
     150             :   {
     151           1 :     return impl->mu;
     152             :   }
     153             : 
     154             :   template<typename DataType_>
     155           1 :   void LMSFilter<DataType_>::set_mode(Mode mode)
     156             :   {
     157           1 :     this->mode = mode;
     158           1 :   }
     159             : 
     160             :   template<typename DataType_>
     161           1 :   typename LMSFilter<DataType_>::Mode LMSFilter<DataType_>::get_mode() const
     162             :   {
     163           1 :     return mode;
     164             :   }
     165             : 
     166             :   template<typename DataType_>
     167           1 :   void LMSFilter<DataType_>::process_impl(gsl::index size) const
     168             :   {
     169           1 :     const DataType* ATK_RESTRICT input = converted_inputs[0];
     170           1 :     const DataType* ATK_RESTRICT ref = converted_inputs[1];
     171           1 :     DataType* ATK_RESTRICT output = outputs[0];
     172             :     
     173           1 :     auto update_function = impl->select(mode);
     174             : 
     175        1201 :     for(gsl::index i = 0; i < size; ++i)
     176             :     {
     177        1200 :       typename LMSFilterImpl::xType x(input - input_delay + i, input_delay + 1, 1);
     178        1200 :       output[i] = impl->w.conjugate().dot(x);
     179        1200 :       if(learning)
     180             :       {
     181        1200 :         (impl.get()->*update_function)(x, TypeTraits<DataType>::conj(ref[i] - output[i]));
     182             :       }
     183             :     }
     184           1 :   }
     185             : 
     186             :   template<typename DataType_>
     187           1 :   const DataType_* LMSFilter<DataType_>::get_w() const
     188             :   {
     189           1 :     return impl->w.data();
     190             :   }
     191             :   
     192             :   template<typename DataType_>
     193           0 :   void LMSFilter<DataType_>::set_w(gsl::not_null<const DataType_*> w)
     194             :   {
     195           0 :     impl->w = Eigen::Map<const typename LMSFilterImpl::wType>(w.get(), get_size());
     196           0 :   }
     197             : 
     198             :   template<typename DataType_>
     199           1 :   void LMSFilter<DataType_>::set_learning(bool learning)
     200             :   {
     201           1 :     this->learning = learning;
     202           1 :   }
     203             : 
     204             :   template<typename DataType_>
     205           2 :   bool LMSFilter<DataType_>::get_learning() const
     206             :   {
     207           2 :     return learning;
     208             :   }
     209             : 
     210             :   template class LMSFilter<double>;
     211             : #if ATK_ENABLE_INSTANTIATION
     212             :   template class LMSFilter<float>;
     213             :   template class LMSFilter<std::complex<float>>;
     214             :   template class LMSFilter<std::complex<double>>;
     215             : #endif
     216             : }

Generated by: LCOV version TK-3.3.0-4-gdba42eea