Line data Source code
1 : /** 2 : * \file RLSFilter.cpp 3 : */ 4 : 5 : #include "RLSFilter.h" 6 : #include <ATK/Core/TypeTraits.h> 7 : #include <ATK/Core/Utilities.h> 8 : 9 : #include <Eigen/Core> 10 : 11 : #include <cstdint> 12 : #include <complex> 13 : #include <stdexcept> 14 : 15 : namespace ATK 16 : { 17 : template<typename DataType_> 18 : class RLSFilter<DataType_>::RLSFilterImpl 19 : { 20 : public: 21 : using PType = Eigen::Matrix<DataType_, Eigen::Dynamic, Eigen::Dynamic>; 22 : using wType = Eigen::Matrix<DataType_, Eigen::Dynamic, 1>; 23 : using xType = Eigen::Map<const wType>; 24 : 25 11 : explicit RLSFilterImpl(gsl::index size) 26 11 : :P(PType::Identity(size, size) / DataType(size)), w(wType::Zero(size)) 27 : { 28 11 : } 29 : 30 66536 : void learn(const xType& x, DataType_ target, DataType_ actual) 31 : { 32 66536 : auto alpha = target - actual; 33 66536 : auto xreverse = x.reverse(); 34 : 35 133072 : wType g = (P * xreverse) / ((xreverse.adjoint() * P * xreverse)(0,0) + static_cast<DataType>(memory)); 36 66536 : PType pupdate = (g * (xreverse.adjoint() * P)); 37 66536 : w = w + TypeTraits<DataType>::conj(alpha) * g; 38 66536 : P = (P - (pupdate + pupdate.transpose()) / 2) * memory; 39 66536 : } 40 : 41 0 : void set_P(const DataType_* P) 42 : { 43 0 : this->P = Eigen::Map<const PType>(P, this->P.rows(), this->P.cols()); 44 0 : } 45 : 46 0 : const DataType_* get_P() const 47 : { 48 0 : return P.data(); 49 : } 50 : 51 1 : void set_w(const DataType_* w) 52 : { 53 1 : this->w = xType(w, this->w.rows(), 1); 54 1 : } 55 : 56 1 : const DataType_* get_w() const 57 : { 58 1 : return w.data(); 59 : } 60 : 61 : PType P; 62 : wType w; 63 : double memory = 0.99; 64 : }; 65 : 66 : template<typename DataType_> 67 11 : RLSFilter<DataType_>::RLSFilter(gsl::index size) 68 11 : :Parent(1, 1), impl(std::make_unique<RLSFilterImpl>(size)), global_size(size) 69 : { 70 11 : input_delay = size + 1; 71 11 : } 72 : 73 : template<typename DataType_> 74 12 : RLSFilter<DataType_>::~RLSFilter() 75 : { 76 12 : } 77 : 78 : template<typename DataType_> 79 2 : void RLSFilter<DataType_>::set_size(gsl::index size) 80 : { 81 2 : if(size == 0) 82 : { 83 1 : throw ATK::RuntimeError("Size must be strictly positive"); 84 : } 85 : 86 1 : impl->P = RLSFilterImpl::PType::Identity(size, size) / DataType(size); 87 1 : impl->w = typename RLSFilterImpl::wType(size, 1); 88 1 : input_delay = size+1; 89 1 : this->global_size = size; 90 1 : } 91 : 92 : template<typename DataType_> 93 1 : gsl::index RLSFilter<DataType_>::get_size() const 94 : { 95 1 : return global_size; 96 : } 97 : 98 : template<typename DataType_> 99 7 : void RLSFilter<DataType_>::set_memory(double memory) 100 : { 101 7 : if(memory >= 1) 102 : { 103 1 : throw ATK::RuntimeError("Memory must be less than 1"); 104 : } 105 6 : if(memory <= 0) 106 : { 107 1 : throw ATK::RuntimeError("Memory must be strictly positive"); 108 : } 109 : 110 5 : impl->memory = memory; 111 5 : } 112 : 113 : template<typename DataType_> 114 1 : double RLSFilter<DataType_>::get_memory() const 115 : { 116 1 : return impl->memory; 117 : } 118 : 119 : template<typename DataType_> 120 6 : void RLSFilter<DataType_>::set_learning(bool learning) 121 : { 122 6 : this->learning = learning; 123 6 : } 124 : 125 : template<typename DataType_> 126 2 : bool RLSFilter<DataType_>::get_learning() const 127 : { 128 2 : return learning; 129 : } 130 : 131 : template<typename DataType_> 132 6 : void RLSFilter<DataType_>::process_impl(gsl::index size) const 133 : { 134 6 : const DataType* ATK_RESTRICT input = converted_inputs[0]; 135 6 : DataType* ATK_RESTRICT output = outputs[0]; 136 : 137 262151 : for(gsl::index i = 0; i < size; ++i) 138 : { 139 262145 : typename RLSFilterImpl::xType x(input - global_size + i, global_size, 1); 140 : 141 : // compute next sample 142 262145 : output[i] = impl->w.adjoint().dot(x.reverse()); 143 : 144 262145 : if(learning) 145 : { 146 : //update w and P 147 66536 : impl->learn(x, input[i], output[i]); 148 : } 149 : } 150 6 : } 151 : 152 : template<typename DataType_> 153 0 : void RLSFilter<DataType_>::set_P(const DataType_* P) 154 : { 155 0 : impl->set_P(P); 156 0 : } 157 : 158 : template<typename DataType_> 159 0 : const DataType_* RLSFilter<DataType_>::get_P() const 160 : { 161 0 : return impl->get_P(); 162 : } 163 : 164 : template<typename DataType_> 165 1 : void RLSFilter<DataType_>::set_w(const DataType_* w) 166 : { 167 1 : impl->set_w(w); 168 1 : } 169 : 170 : template<typename DataType_> 171 1 : const DataType_* RLSFilter<DataType_>::get_w() const 172 : { 173 1 : return impl->get_w(); 174 : } 175 : 176 : template class RLSFilter<double>; 177 : #if ATK_ENABLE_INSTANTIATION 178 : template class RLSFilter<float>; 179 : template class RLSFilter<std::complex<float>>; 180 : template class RLSFilter<std::complex<double>>; 181 : #endif 182 : }