Line data Source code
1 : /**
2 : * \file BlockLMSFilter.cpp
3 : */
4 :
5 : #include "BlockLMSFilter.h"
6 : #include <ATK/Core/TypeTraits.h>
7 : #include <ATK/Core/Utilities.h>
8 : #include <ATK/Utility/FFT.h>
9 :
10 : #include <Eigen/Core>
11 :
12 : #include <complex>
13 : #include <cstdint>
14 : #include <stdexcept>
15 :
16 : namespace ATK
17 : {
18 : template<typename DataType_>
19 : class BlockLMSFilter<DataType_>::BlockLMSFilterImpl
20 : {
21 : public:
22 : using cwType = Eigen::Matrix<std::complex<double>, Eigen::Dynamic, 1>;
23 : using cxType = Eigen::Map<const cwType>;
24 : using wType = Eigen::Matrix<DataType_, Eigen::Dynamic, 1>;
25 : using xType = Eigen::Map<const wType>;
26 :
27 : /// FFT of the current coefficients
28 : cwType wfft;
29 : /// Current accumulated input
30 : std::vector<DataType_> block_input;
31 : /// Current accumulated ref
32 : std::vector<DataType_> block_ref;
33 : /// Current accumulated error
34 : std::vector<DataType_> block_error;
35 :
36 : /// Temporary storage
37 : std::vector<std::complex<double> > block_fft;
38 : /// Temporary storage
39 : std::vector<std::complex<double> > block_fft2;
40 : /// Temporary storage
41 : std::vector<DataType_> block_ifft;
42 :
43 : FFT<double> fft;
44 : /// Memory factor
45 : double alpha = .99;
46 : /// line search
47 : double mu = 0.05;
48 : /// block size
49 : gsl::index block_size = 0;
50 : gsl::index accumulate_block_size = 0;
51 : bool learning = true;
52 :
53 15 : explicit BlockLMSFilterImpl(gsl::index size)
54 15 : :wfft(cwType::Zero(2*size)), block_input(2 * size, DataType_(0)), block_ref(size, DataType_(0)), block_error(size, DataType_(0)),
55 15 : block_fft(2 * size), block_fft2(2 * size), block_ifft(2 * size), block_size(size)
56 : {
57 15 : fft.set_size(2 * size);
58 15 : }
59 :
60 1300 : void apply_update()
61 : {
62 1300 : ++accumulate_block_size;
63 1300 : if (accumulate_block_size == block_size)
64 : {
65 13 : fft.process_forward(block_input.data(), block_fft2.data(), block_size * 2);
66 2613 : for(gsl::index i = 0; i < 2 * block_size; ++i)
67 : {
68 2600 : block_fft[i] = block_fft2[i] * wfft(i, 0) * std::complex<double>(block_size * 2); // Diagonal U * FFT factor
69 : }
70 13 : fft.process_backward(block_fft.data(), block_ifft.data(), block_size * 2);
71 1313 : for (gsl::index i = 0; i < block_size; ++i)
72 : {
73 1300 : block_ifft[block_size + i] = block_ref[i] - block_ifft[block_size + i]; // error on last elements of Y
74 1300 : block_error[i] = block_ifft[block_size + i];
75 : }
76 13 : if (learning)
77 : {
78 13 : std::fill(block_ifft.begin(), block_ifft.begin() + block_size, 0);
79 13 : fft.process_forward(block_ifft.data(), block_fft.data(), block_size * 2); // FFT of the error stored in ifft
80 2613 : for (gsl::index i = 0; i < 2 * block_size; ++i)
81 : {
82 2600 : block_fft[i] = std::conj(block_fft2[i]) * block_fft[i] * std::complex<double>(block_size * 2); // diagonal * FFT factor
83 : }
84 13 : fft.process_backward(block_fft.data(), block_ifft.data(), 2 * block_size);
85 13 : fft.process_forward(block_ifft.data(), block_fft.data(), block_size);
86 13 : wfft = alpha * wfft + static_cast<std::complex<double>>(mu) * cxType(block_fft.data(), 2 * block_size);
87 : }
88 :
89 13 : accumulate_block_size = 0;
90 13 : std::memcpy(&block_input[0], &block_input[block_size], block_size * sizeof(DataType_));
91 : }
92 1300 : }
93 :
94 1300 : void update(DataType input, DataType ref, DataType& error)
95 : {
96 1300 : block_input[block_size + accumulate_block_size] = input;
97 1300 : error = block_ref[accumulate_block_size] - block_error[accumulate_block_size];
98 1300 : block_ref[accumulate_block_size] = ref;
99 :
100 1300 : apply_update();
101 1300 : }
102 : };
103 :
104 : template<typename DataType_>
105 14 : BlockLMSFilter<DataType_>::BlockLMSFilter(gsl::index size)
106 16 : :Parent(2, 1), impl(std::make_unique<BlockLMSFilterImpl>(size))
107 : {
108 14 : if (size == 0)
109 : {
110 1 : throw RuntimeError("Size must be strictly positive");
111 : }
112 13 : }
113 :
114 : template<typename DataType_>
115 14 : BlockLMSFilter<DataType_>::~BlockLMSFilter()
116 : {
117 14 : }
118 :
119 : template<typename DataType_>
120 2 : void BlockLMSFilter<DataType_>::set_size(gsl::index size)
121 : {
122 2 : if(size == 0)
123 : {
124 1 : throw RuntimeError("Size must be strictly positive");
125 : }
126 1 : auto block_size = impl->block_size;
127 1 : impl = std::make_unique<BlockLMSFilterImpl>(size);
128 1 : set_block_size(block_size);
129 1 : }
130 :
131 : template<typename DataType_>
132 1 : gsl::index BlockLMSFilter<DataType_>::get_size() const
133 : {
134 1 : return impl->wfft.size() / 2;
135 : }
136 :
137 : template<typename DataType_>
138 3 : void BlockLMSFilter<DataType_>::set_block_size(gsl::index size)
139 : {
140 3 : if (size == 0)
141 : {
142 1 : throw ATK::RuntimeError("Block size must be strictly positive");
143 : }
144 2 : impl->accumulate_block_size = 0;
145 2 : impl->block_size = size;
146 2 : impl->block_input.assign(2 * size, 0);
147 2 : impl->block_ref.assign(size, 0);
148 2 : impl->block_fft.assign(2 * size, 0);
149 2 : impl->block_fft2.assign(2 * size, 0);
150 2 : impl->block_ifft.assign(2 * size, 0);
151 2 : }
152 :
153 : template<typename DataType_>
154 1 : gsl::index BlockLMSFilter<DataType_>::get_block_size() const
155 : {
156 1 : return impl->block_size;
157 : }
158 :
159 : template<typename DataType_>
160 4 : void BlockLMSFilter<DataType_>::set_memory(double memory)
161 : {
162 4 : if (memory >= 1)
163 : {
164 1 : throw ATK::RuntimeError("Memory must be less than 1");
165 : }
166 3 : if (memory <= 0)
167 : {
168 1 : throw ATK::RuntimeError("Memory must be strictly positive");
169 : }
170 :
171 2 : impl->alpha = memory;
172 2 : }
173 :
174 : template<typename DataType_>
175 1 : double BlockLMSFilter<DataType_>::get_memory() const
176 : {
177 1 : return impl->alpha;
178 : }
179 :
180 : template<typename DataType_>
181 4 : void BlockLMSFilter<DataType_>::set_mu(double mu)
182 : {
183 4 : if (mu >= 1)
184 : {
185 1 : throw ATK::RuntimeError("Mu must be less than 1");
186 : }
187 3 : if (mu <= 0)
188 : {
189 1 : throw ATK::RuntimeError("Mu must be strictly positive");
190 : }
191 :
192 2 : impl->mu = mu;
193 2 : }
194 :
195 : template<typename DataType_>
196 1 : double BlockLMSFilter<DataType_>::get_mu() const
197 : {
198 1 : return impl->mu;
199 : }
200 :
201 : template<typename DataType_>
202 2 : void BlockLMSFilter<DataType_>::process_impl(gsl::index size) const
203 : {
204 2 : const DataType* ATK_RESTRICT input = converted_inputs[0];
205 2 : const DataType* ATK_RESTRICT ref = converted_inputs[1];
206 2 : DataType* ATK_RESTRICT output = outputs[0];
207 :
208 1302 : for(gsl::index i = 0; i < size; ++i)
209 : {
210 1300 : impl->update(input[i], ref[i], output[i]);
211 : }
212 2 : }
213 :
214 : template<typename DataType_>
215 1 : const std::complex<double>* BlockLMSFilter<DataType_>::get_w() const
216 : {
217 1 : return impl->wfft.data();
218 : }
219 :
220 : template<typename DataType_>
221 0 : void BlockLMSFilter<DataType_>::set_w(gsl::not_null<const std::complex<double>*> w)
222 : {
223 0 : impl->wfft = Eigen::Map<const typename BlockLMSFilterImpl::cwType>(w.get(), get_size() * 2);
224 0 : }
225 :
226 : template<typename DataType_>
227 1 : void BlockLMSFilter<DataType_>::set_learning(bool learning)
228 : {
229 1 : impl->learning = learning;
230 1 : }
231 :
232 : template<typename DataType_>
233 2 : bool BlockLMSFilter<DataType_>::get_learning() const
234 : {
235 2 : return impl->learning;
236 : }
237 :
238 : template class BlockLMSFilter<double>;
239 : #if ATK_ENABLE_INSTANTIATION
240 : template class BlockLMSFilter<std::complex<double>>;
241 : #endif
242 : }
|