LCOV - code coverage report
Current view: top level - Adaptive - RLSFilter.cpp (source / functions) Hit Total Coverage
Test: coverage.info.cleaned Lines: 60 70 85.7 %
Date: 2021-02-18 20:07:22 Functions: 16 20 80.0 %

          Line data    Source code
       1             : /**
       2             :  * \file RLSFilter.cpp
       3             :  */
       4             : 
       5             : #include "RLSFilter.h"
       6             : #include <ATK/Core/TypeTraits.h>
       7             : #include <ATK/Core/Utilities.h>
       8             : 
       9             : #include <Eigen/Core>
      10             : 
      11             : #include <cstdint>
      12             : #include <complex>
      13             : #include <stdexcept>
      14             : 
      15             : namespace ATK
      16             : {
      17             :   template<typename DataType_>
      18             :   class RLSFilter<DataType_>::RLSFilterImpl
      19             :   {
      20             :   public:
      21             :     using PType = Eigen::Matrix<DataType_, Eigen::Dynamic, Eigen::Dynamic>;
      22             :     using wType = Eigen::Matrix<DataType_, Eigen::Dynamic, 1>;
      23             :     using xType = Eigen::Map<const wType>;
      24             : 
      25          11 :     explicit RLSFilterImpl(gsl::index size)
      26          11 :       :P(PType::Identity(size, size) / DataType(size)), w(wType::Zero(size))
      27             :     {
      28          11 :     }
      29             : 
      30       66536 :     void learn(const xType& x, DataType_ target, DataType_ actual)
      31             :     {
      32       66536 :       auto alpha = target - actual;
      33       66536 :       auto xreverse = x.reverse();
      34             : 
      35      133072 :       wType g = (P * xreverse) / ((xreverse.adjoint() * P * xreverse)(0,0) + static_cast<DataType>(memory));
      36       66536 :       PType pupdate = (g * (xreverse.adjoint() * P));
      37       66536 :       w = w + TypeTraits<DataType>::conj(alpha) * g;
      38       66536 :       P = (P - (pupdate + pupdate.transpose()) / 2) * memory;
      39       66536 :     }
      40             : 
      41           0 :     void set_P(const DataType_* P)
      42             :     {
      43           0 :       this->P = Eigen::Map<const PType>(P, this->P.rows(), this->P.cols());
      44           0 :     }
      45             : 
      46           0 :     const DataType_* get_P() const
      47             :     {
      48           0 :       return P.data();
      49             :     }
      50             : 
      51           1 :     void set_w(const DataType_* w)
      52             :     {
      53           1 :       this->w = xType(w, this->w.rows(), 1);
      54           1 :     }
      55             : 
      56           1 :     const DataType_* get_w() const
      57             :     {
      58           1 :       return w.data();
      59             :     }
      60             : 
      61             :     PType P;
      62             :     wType w;
      63             :     double memory = 0.99;
      64             :   };
      65             : 
      66             :   template<typename DataType_>
      67          11 :   RLSFilter<DataType_>::RLSFilter(gsl::index size)
      68          11 :   :Parent(1, 1), impl(std::make_unique<RLSFilterImpl>(size)), global_size(size)
      69             :   {
      70          11 :     input_delay = size + 1;
      71          11 :   }
      72             :   
      73             :   template<typename DataType_>
      74          12 :   RLSFilter<DataType_>::~RLSFilter()
      75             :   {
      76          12 :   }
      77             :   
      78             :   template<typename DataType_>
      79           2 :   void RLSFilter<DataType_>::set_size(gsl::index size)
      80             :   {
      81           2 :     if(size == 0)
      82             :     {
      83           1 :       throw ATK::RuntimeError("Size must be strictly positive");
      84             :     }
      85             : 
      86           1 :     impl->P = RLSFilterImpl::PType::Identity(size, size) / DataType(size);
      87           1 :     impl->w = typename RLSFilterImpl::wType(size, 1);
      88           1 :     input_delay = size+1;
      89           1 :     this->global_size = size;
      90           1 :   }
      91             : 
      92             :   template<typename DataType_>
      93           1 :   gsl::index RLSFilter<DataType_>::get_size() const
      94             :   {
      95           1 :     return global_size;
      96             :   }
      97             :   
      98             :   template<typename DataType_>
      99           7 :   void RLSFilter<DataType_>::set_memory(double memory)
     100             :   {
     101           7 :     if(memory >= 1)
     102             :     {
     103           1 :       throw ATK::RuntimeError("Memory must be less than 1");
     104             :     }
     105           6 :     if(memory <= 0)
     106             :     {
     107           1 :       throw ATK::RuntimeError("Memory must be strictly positive");
     108             :     }
     109             :     
     110           5 :     impl->memory = memory;
     111           5 :   }
     112             :   
     113             :   template<typename DataType_>
     114           1 :   double RLSFilter<DataType_>::get_memory() const
     115             :   {
     116           1 :     return impl->memory;
     117             :   }
     118             :   
     119             :   template<typename DataType_>
     120           6 :   void RLSFilter<DataType_>::set_learning(bool learning)
     121             :   {
     122           6 :     this->learning = learning;
     123           6 :   }
     124             :   
     125             :   template<typename DataType_>
     126           2 :   bool RLSFilter<DataType_>::get_learning() const
     127             :   {
     128           2 :     return learning;
     129             :   }
     130             : 
     131             :   template<typename DataType_>
     132           6 :   void RLSFilter<DataType_>::process_impl(gsl::index size) const
     133             :   {
     134           6 :     const DataType* ATK_RESTRICT input = converted_inputs[0];
     135           6 :     DataType* ATK_RESTRICT output = outputs[0];
     136             :     
     137      262151 :     for(gsl::index i = 0; i < size; ++i)
     138             :     {
     139      262145 :       typename RLSFilterImpl::xType x(input - global_size + i, global_size, 1);
     140             :       
     141             :       // compute next sample
     142      262145 :       output[i] = impl->w.adjoint().dot(x.reverse());
     143             :       
     144      262145 :       if(learning)
     145             :       {
     146             :         //update w and P
     147       66536 :         impl->learn(x, input[i], output[i]);
     148             :       }
     149             :     }
     150           6 :   }
     151             : 
     152             :   template<typename DataType_>
     153           0 :   void RLSFilter<DataType_>::set_P(const DataType_* P)
     154             :   {
     155           0 :      impl->set_P(P);
     156           0 :   }
     157             :   
     158             :   template<typename DataType_>
     159           0 :   const DataType_* RLSFilter<DataType_>::get_P() const
     160             :   {
     161           0 :     return impl->get_P();
     162             :   }
     163             :   
     164             :   template<typename DataType_>
     165           1 :   void RLSFilter<DataType_>::set_w(const DataType_* w)
     166             :   {
     167           1 :     impl->set_w(w);
     168           1 :   }
     169             :   
     170             :   template<typename DataType_>
     171           1 :   const DataType_* RLSFilter<DataType_>::get_w() const
     172             :   {
     173           1 :     return impl->get_w();
     174             :   }
     175             : 
     176             :   template class RLSFilter<double>;
     177             : #if ATK_ENABLE_INSTANTIATION
     178             :   template class RLSFilter<float>;
     179             :   template class RLSFilter<std::complex<float>>;
     180             :   template class RLSFilter<std::complex<double>>;
     181             : #endif
     182             : }

Generated by: LCOV version TK-3.3.0-4-gdba42eea