Line data Source code
1 : /** 2 : * \file LMSFilter.cpp 3 : */ 4 : 5 : #include "LMSFilter.h" 6 : #include <ATK/Core/TypeTraits.h> 7 : #include <ATK/Core/Utilities.h> 8 : 9 : #include <Eigen/Core> 10 : 11 : #include <complex> 12 : #include <cstdint> 13 : #include <stdexcept> 14 : 15 : namespace ATK 16 : { 17 : template<typename DataType_> 18 : class LMSFilter<DataType_>::LMSFilterImpl 19 : { 20 : public: 21 : using wType = Eigen::Matrix<DataType_, Eigen::Dynamic, 1>; 22 : using xType = Eigen::Map<const wType>; 23 : 24 : wType w; 25 : /// Memory factor 26 : double alpha = 0.99; 27 : /// line search 28 : double mu = 0.05; 29 : 30 13 : explicit LMSFilterImpl(gsl::index size) 31 13 : :w(wType::Zero(size)) 32 : { 33 13 : } 34 : 35 : using UpdateFunction = void (LMSFilterImpl::*)(const xType& x, DataType error); 36 : 37 1200 : void update(const xType& x, DataType error) 38 : { 39 1200 : w = static_cast<DataType>(alpha) * w + static_cast<DataType>(mu) * error * x; 40 1200 : } 41 : 42 0 : void update_normalized(const xType& x, DataType error) 43 : { 44 0 : w = static_cast<DataType>(alpha) * w + static_cast<DataType>(mu) * error * x / (std::numeric_limits<DataType>::epsilon() + static_cast<DataType>(x.squaredNorm())); 45 0 : } 46 : 47 0 : void update_signerror(const xType& x, DataType error) 48 : { 49 0 : w = static_cast<DataType>(alpha) * w + static_cast<DataType>(mu) * error / (std::numeric_limits<DataType>::epsilon() + std::abs(error)) * x; 50 0 : } 51 : 52 0 : void update_signdata(const xType& x, DataType error) 53 : { 54 0 : w = static_cast<DataType>(alpha) * w.array() + static_cast<DataType>(mu) * error * x.array() / (x.cwiseAbs().template cast<DataType>().array() + static_cast<DataType>(std::numeric_limits<DataType>::epsilon())); 55 0 : } 56 : 57 0 : void update_signsign(const xType& x, DataType error) 58 : { 59 0 : w = static_cast<DataType>(alpha) * w.array() + static_cast<DataType>(mu) * error / (std::numeric_limits<DataType>::epsilon() + std::abs(error)) * x.array() / (x.cwiseAbs().template cast<DataType>().array() + static_cast<DataType>(std::numeric_limits<DataType>::epsilon())); 60 0 : } 61 : 62 1 : UpdateFunction select(Mode mode) 63 : { 64 1 : switch (mode) 65 : { 66 1 : case Mode::NORMAL: 67 1 : return &LMSFilterImpl::update; 68 0 : case Mode::NORMALIZED: 69 0 : return &LMSFilterImpl::update_normalized; 70 0 : case Mode::SIGNERROR: 71 0 : return &LMSFilterImpl::update_signerror; 72 0 : case Mode::SIGNDATA: 73 0 : return &LMSFilterImpl::update_signdata; 74 0 : case Mode::SIGNSIGN: 75 0 : return &LMSFilterImpl::update_signsign; 76 0 : default: 77 0 : throw std::range_error("Wrong mode for LMS filter"); 78 : } 79 : } 80 : }; 81 : 82 : template<typename DataType_> 83 12 : LMSFilter<DataType_>::LMSFilter(gsl::index size) 84 12 : :Parent(2, 1), impl(std::make_unique<LMSFilterImpl>(size)) 85 : { 86 12 : input_delay = size - 1; 87 12 : } 88 : 89 : template<typename DataType_> 90 13 : LMSFilter<DataType_>::~LMSFilter() 91 : { 92 13 : } 93 : 94 : template<typename DataType_> 95 2 : void LMSFilter<DataType_>::set_size(gsl::index size) 96 : { 97 2 : if(size == 0) 98 : { 99 1 : throw RuntimeError("Size must be strictly positive"); 100 : } 101 : 102 1 : input_delay = size - 1; 103 1 : impl = std::make_unique<LMSFilterImpl>(size); 104 1 : } 105 : 106 : template<typename DataType_> 107 1 : gsl::index LMSFilter<DataType_>::get_size() const 108 : { 109 1 : return input_delay + 1; 110 : } 111 : 112 : template<typename DataType_> 113 4 : void LMSFilter<DataType_>::set_memory(double memory) 114 : { 115 4 : if (memory >= 1) 116 : { 117 1 : throw ATK::RuntimeError("Memory must be less than 1"); 118 : } 119 3 : if (memory <= 0) 120 : { 121 1 : throw ATK::RuntimeError("Memory must be strictly positive"); 122 : } 123 : 124 2 : impl->alpha = memory; 125 2 : } 126 : 127 : template<typename DataType_> 128 1 : double LMSFilter<DataType_>::get_memory() const 129 : { 130 1 : return impl->alpha; 131 : } 132 : 133 : template<typename DataType_> 134 4 : void LMSFilter<DataType_>::set_mu(double mu) 135 : { 136 4 : if (mu >= 1) 137 : { 138 1 : throw ATK::RuntimeError("Mu must be less than 1"); 139 : } 140 3 : if (mu <= 0) 141 : { 142 1 : throw ATK::RuntimeError("Mu must be strictly positive"); 143 : } 144 : 145 2 : impl->mu = mu; 146 2 : } 147 : 148 : template<typename DataType_> 149 1 : double LMSFilter<DataType_>::get_mu() const 150 : { 151 1 : return impl->mu; 152 : } 153 : 154 : template<typename DataType_> 155 1 : void LMSFilter<DataType_>::set_mode(Mode mode) 156 : { 157 1 : this->mode = mode; 158 1 : } 159 : 160 : template<typename DataType_> 161 1 : typename LMSFilter<DataType_>::Mode LMSFilter<DataType_>::get_mode() const 162 : { 163 1 : return mode; 164 : } 165 : 166 : template<typename DataType_> 167 1 : void LMSFilter<DataType_>::process_impl(gsl::index size) const 168 : { 169 1 : const DataType* ATK_RESTRICT input = converted_inputs[0]; 170 1 : const DataType* ATK_RESTRICT ref = converted_inputs[1]; 171 1 : DataType* ATK_RESTRICT output = outputs[0]; 172 : 173 1 : auto update_function = impl->select(mode); 174 : 175 1201 : for(gsl::index i = 0; i < size; ++i) 176 : { 177 1200 : typename LMSFilterImpl::xType x(input - input_delay + i, input_delay + 1, 1); 178 1200 : output[i] = impl->w.conjugate().dot(x); 179 1200 : if(learning) 180 : { 181 1200 : (impl.get()->*update_function)(x, TypeTraits<DataType>::conj(ref[i] - output[i])); 182 : } 183 : } 184 1 : } 185 : 186 : template<typename DataType_> 187 1 : const DataType_* LMSFilter<DataType_>::get_w() const 188 : { 189 1 : return impl->w.data(); 190 : } 191 : 192 : template<typename DataType_> 193 0 : void LMSFilter<DataType_>::set_w(gsl::not_null<const DataType_*> w) 194 : { 195 0 : impl->w = Eigen::Map<const typename LMSFilterImpl::wType>(w.get(), get_size()); 196 0 : } 197 : 198 : template<typename DataType_> 199 1 : void LMSFilter<DataType_>::set_learning(bool learning) 200 : { 201 1 : this->learning = learning; 202 1 : } 203 : 204 : template<typename DataType_> 205 2 : bool LMSFilter<DataType_>::get_learning() const 206 : { 207 2 : return learning; 208 : } 209 : 210 : template class LMSFilter<double>; 211 : #if ATK_ENABLE_INSTANTIATION 212 : template class LMSFilter<float>; 213 : template class LMSFilter<std::complex<float>>; 214 : template class LMSFilter<std::complex<double>>; 215 : #endif 216 : }