Line data Source code
1 : /**
2 : * \file RemezBasedFilter.cpp
3 : */
4 :
5 : #include "RemezBasedFilter.h"
6 : #include <ATK/Core/Utilities.h>
7 : #include <ATK/EQ/FIRFilter.h>
8 : #include <ATK/Utility/FFT.h>
9 :
10 : #include <boost/math/constants/constants.hpp>
11 :
12 : #include <Eigen/Dense>
13 :
14 : #include <vector>
15 :
16 : namespace
17 : {
18 : template<class DataType>
19 : class RemezBuilder
20 : {
21 : public:
22 : using AlignedScalarVector = typename ATK::TypedBaseFilter<DataType>::AlignedScalarVector;
23 : private:
24 : const static gsl::index grid_size = 1024; // grid size, power of two better for FFT
25 : constexpr static DataType SN = 1e-8;
26 :
27 : gsl::index M;
28 : std::vector<DataType> grid;
29 : std::vector<std::pair<std::pair<DataType, DataType>, std::pair<DataType, DataType>> > target;
30 :
31 : /// Computed coefficients
32 : AlignedScalarVector coeffs;
33 : /// Selected indices
34 : std::vector<gsl::index> indices;
35 : /// Weight function on the grid
36 : std::vector<DataType> weights;
37 : /// Objective function on the grid
38 : std::vector<DataType> objective;
39 : /// Alternate signs
40 : std::vector<int> s;
41 :
42 : ATK::FFT<DataType> fft_processor;
43 :
44 : public:
45 3 : RemezBuilder(gsl::index order, const std::vector<std::pair<std::pair<DataType, DataType>, std::pair<DataType, DataType>> >& target)
46 3 : :M(order / 2), target(target)
47 : {
48 3 : grid.resize(grid_size);
49 3075 : for(gsl::index i = 0; i < grid_size; ++i)
50 : {
51 3072 : grid[i] = i * boost::math::constants::pi<DataType>() / grid_size;
52 : }
53 3 : fft_processor.set_size(2 * grid_size);
54 3 : }
55 :
56 1 : void init()
57 : {
58 1 : coeffs.assign(M * 2 + 1, 0);
59 1 : indices.assign(M + 2, -1);
60 1 : s.assign(M+2, 0);
61 :
62 1 : weights.assign(grid_size, 0);
63 1 : objective.assign(grid_size, 0);
64 :
65 1 : int current_template = 0;
66 1025 : for(gsl::index i = 0; i < grid_size; ++i)
67 : {
68 1024 : auto reduced_freq = grid[i] / boost::math::constants::pi<DataType>();
69 1024 : if(reduced_freq > target[current_template].first.second && current_template + 1 < target.size())
70 : {
71 1 : ++current_template;
72 : }
73 1024 : if (reduced_freq < target[current_template].first.first || reduced_freq > target[current_template].first.second)
74 : {
75 102 : weights[i] = 0;
76 102 : objective[i] = 0;
77 : }
78 : else
79 : {
80 922 : weights[i] = target[current_template].second.second;
81 922 : objective[i] = target[current_template].second.first;
82 : }
83 : }
84 1 : int flag = -1;
85 9 : for (gsl::index i = 0; i < M + 2; ++i)
86 : {
87 8 : s[i] = flag;
88 8 : flag = -flag;
89 : }
90 1 : indices = set_starting_conditions();
91 1 : }
92 :
93 1 : std::vector<gsl::index> set_starting_conditions() const
94 : {
95 1 : std::vector<gsl::index> indices;
96 :
97 2 : std::vector<gsl::index> valid_indices;
98 1025 : for (gsl::index i = 0; i < grid_size; ++i)
99 : {
100 1024 : if (weights[i] != 0)
101 : {
102 922 : valid_indices.push_back(i);
103 : }
104 : }
105 :
106 9 : for (gsl::index i = 0; i < M + 2; ++i)
107 : {
108 8 : indices.push_back(valid_indices[std::lround(valid_indices.size() / (M + 4.) * (i + 1))]);
109 : }
110 :
111 2 : return indices;
112 : }
113 :
114 3 : AlignedScalarVector build()
115 : {
116 3 : if(target.empty())
117 : {
118 2 : coeffs.clear();
119 2 : return coeffs;
120 : }
121 1 : init();
122 4 : while(true)
123 : {
124 5 : Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic> A(M + 2, M + 2);
125 45 : for (gsl::index i = 0; i < M + 2; ++i)
126 : {
127 320 : for (gsl::index j = 0; j < M + 1; ++j)
128 : {
129 280 : A(i, j) = std::cos(grid[indices[i]] * j);
130 : }
131 : }
132 45 : for (gsl::index i = 0; i < M + 2; ++i)
133 : {
134 40 : A(i, M + 1) = s[i] / weights[indices[i]];
135 : }
136 :
137 5 : Eigen::Matrix<DataType, Eigen::Dynamic, 1> b(M + 2, 1);
138 45 : for (gsl::index i = 0; i < M + 2; ++i)
139 : {
140 40 : b(i) = objective[indices[i]];
141 : }
142 5 : Eigen::Matrix<DataType, Eigen::Dynamic, 1> x = A.colPivHouseholderQr().solve(b);
143 :
144 35 : for (gsl::index i = 0; i < M; ++i)
145 : {
146 30 : coeffs[i] = coeffs[2 * M - i] = x[M - i] / 2;
147 : }
148 5 : coeffs[M] = x[0];
149 5 : auto delta = std::abs(x[M + 1]); // maximum cost
150 :
151 5 : auto newerr = compute_new_error();
152 5 : auto new_indices = locmax(newerr);
153 :
154 5 : filter_SN(delta, new_indices, newerr);
155 5 : filter_monotony(new_indices, newerr);
156 :
157 5 : DataType max = 0;
158 45 : for (auto indice : new_indices)
159 : {
160 40 : max = std::max(max, std::abs(newerr[indice]));
161 : }
162 :
163 5 : if ((max - delta) / delta < SN)
164 : {
165 1 : break;
166 : }
167 :
168 4 : indices = std::move(new_indices);
169 : }
170 :
171 1 : return coeffs;
172 : }
173 :
174 : private:
175 : /// Creates a spectral response for a given set of coeffs through an FFT
176 5 : std::vector<DataType> firamp() const
177 : {
178 10 : std::vector<std::complex<DataType>> output(grid_size * 2);
179 5 : fft_processor.process_forward(coeffs.data(), output.data(), coeffs.size());
180 5 : std::vector<DataType> amp(grid_size);
181 5125 : for (gsl::index i = 0; i < grid_size; ++i)
182 : {
183 5120 : amp[i] = (std::complex<DataType>(std::cos(M * grid[i]), std::sin(M * grid[i])) * output[i]).real() * grid_size * 2;
184 : }
185 10 : return amp;
186 : }
187 :
188 5 : std::vector<DataType> compute_new_error() const
189 : {
190 10 : auto fir_result = firamp();
191 5 : std::vector<DataType> newerr(grid_size);
192 5125 : for (gsl::index i = 0; i < grid_size; ++i)
193 : {
194 5120 : newerr[i] = (fir_result[i] - objective[i]) * weights[i];
195 : }
196 :
197 10 : return newerr;
198 : }
199 :
200 : // Finds min and max
201 5 : std::vector<gsl::index> locmax(const std::vector<DataType>& data) const
202 : {
203 5 : std::vector<gsl::index> v;
204 :
205 10 : std::vector<DataType> temp1;
206 10 : std::vector<DataType> temp2;
207 5 : temp1.push_back(data[0] - 1);
208 5120 : for (gsl::index i = 0; i < data.size() - 1; ++i)
209 : {
210 5115 : temp1.push_back(data[i]);
211 5115 : temp2.push_back(data[i + 1]);
212 : }
213 5 : temp2.push_back(data.back() - 1);
214 :
215 5125 : for (gsl::index i = 0; i < data.size(); ++i)
216 : {
217 5120 : if ((data[i] > temp1[i]) && (data[i] > temp2[i]))
218 : {
219 25 : v.push_back(i);
220 : }
221 : }
222 :
223 5 : temp1.clear();
224 5 : temp2.clear();
225 5 : temp1.push_back(-data[0] - 1);
226 5120 : for (gsl::index i = 0; i < data.size() - 1; ++i)
227 : {
228 5115 : temp1.push_back(-data[i]);
229 5115 : temp2.push_back(-data[i + 1]);
230 : }
231 5 : temp2.push_back(-data.back() - 1);
232 :
233 5125 : for (gsl::index i = 0; i < data.size(); ++i)
234 : {
235 5120 : if ((-data[i] > temp1[i]) && (-data[i] > temp2[i]))
236 : {
237 20 : v.push_back(i);
238 : }
239 : }
240 :
241 5 : std::sort(v.begin(), v.end());
242 :
243 10 : return v;
244 : }
245 :
246 5 : void filter_SN(DataType delta, std::vector<gsl::index>& indices, const std::vector<DataType>& err) const
247 : {
248 10 : std::vector<gsl::index> new_indices;
249 :
250 50 : for (auto indice : indices)
251 : {
252 45 : if (std::abs(err[indice]) > (delta - SN))
253 : {
254 41 : new_indices.push_back(indice);
255 : }
256 : }
257 :
258 5 : indices = std::move(new_indices);
259 5 : }
260 :
261 5 : std::vector<gsl::index> etap(const std::vector<DataType>& data) const
262 : {
263 5 : std::vector<gsl::index> v;
264 5 : auto xe = data[0];
265 5 : gsl::index xv = 0;
266 41 : for (gsl::index i = 1; i < data.size(); ++i)
267 : {
268 36 : if (std::signbit(data[i]) == std::signbit(xe))
269 : {
270 1 : if (std::abs(data[i]) > std::abs(xe))
271 : {
272 0 : xe = data[i];
273 0 : xv = i;
274 : }
275 : }
276 : else
277 : {
278 35 : v.push_back(xv);
279 35 : xe = data[i];
280 35 : xv = i;
281 : }
282 : }
283 5 : v.push_back(xv);
284 10 : return v;
285 : }
286 :
287 5 : void filter_monotony(std::vector<gsl::index>& indices, const std::vector<DataType>& err) const
288 : {
289 10 : std::vector<DataType> filtered_err;
290 46 : for (auto indice : indices)
291 : {
292 41 : filtered_err.push_back(err[indice]);
293 : }
294 :
295 10 : auto selected_indices = etap(filtered_err);
296 10 : std::vector<gsl::index> new_indices;
297 45 : for (gsl::index i = 0; i < M + 2; ++i)
298 : {
299 40 : new_indices.push_back(indices[selected_indices[i]]);
300 : }
301 5 : indices = std::move(new_indices);
302 5 : }
303 : };
304 : }
305 :
306 : namespace ATK
307 : {
308 : template<class DataType>
309 3 : RemezBasedCoefficients<DataType>::RemezBasedCoefficients(gsl::index nb_channels)
310 3 : :Parent(nb_channels, nb_channels)
311 : {
312 3 : }
313 :
314 : template<class DataType>
315 0 : RemezBasedCoefficients<DataType>::RemezBasedCoefficients(RemezBasedCoefficients&& other)
316 0 : :Parent(std::move(other)), target(std::move(other.target)), in_order(std::move(other.in_order)), coefficients_in(std::move(other.coefficients_in))
317 : {
318 0 : }
319 :
320 : template<class DataType>
321 3 : void RemezBasedCoefficients<DataType>::set_template(const std::vector<std::pair<std::pair<CoeffDataType, CoeffDataType>, std::pair<CoeffDataType, CoeffDataType> > >& target)
322 : {
323 3 : this->target = target;
324 3 : setup();
325 2 : }
326 :
327 : template<class DataType>
328 0 : const std::vector<std::pair<std::pair<typename RemezBasedCoefficients<DataType>::CoeffDataType, typename RemezBasedCoefficients<DataType>::CoeffDataType>, std::pair<typename RemezBasedCoefficients<DataType>::CoeffDataType, typename RemezBasedCoefficients<DataType>::CoeffDataType> > >& RemezBasedCoefficients<DataType>::get_template() const
329 : {
330 0 : return target;
331 : }
332 :
333 : template<class DataType>
334 3 : void RemezBasedCoefficients<DataType>::set_order(gsl::index order)
335 : {
336 3 : if(order % 2 == 1)
337 : {
338 1 : throw ATK::RuntimeError("Need an even filter order (considering order 0 has 1 coefficients)");
339 : }
340 2 : in_order = order;
341 2 : setup();
342 2 : }
343 :
344 : template<class DataType>
345 11 : void RemezBasedCoefficients<DataType>::setup()
346 : {
347 11 : Parent::setup();
348 :
349 11 : std::sort(target.begin(), target.end());
350 13 : for(gsl::index i = 0; i + 1 < target.size(); ++i)
351 : {
352 3 : if(target[i].first.second > target[i + 1].first.first)
353 : {
354 1 : target.clear();
355 1 : throw ATK::RuntimeError("Bad template");
356 : }
357 : }
358 :
359 10 : if (in_order > 0)
360 : {
361 3 : RemezBuilder<CoeffDataType> builder(in_order, target);
362 3 : coefficients_in = builder.build();
363 : }
364 10 : }
365 :
366 : template class ATK_EQ_EXPORT RemezBasedCoefficients<double>;
367 : #if ATK_ENABLE_INSTANTIATION
368 : template class ATK_EQ_EXPORT RemezBasedCoefficients<std::complex<double> >;
369 : #endif
370 : }
|