LCOV - code coverage report
Current view: top level - EQ - RemezBasedFilter.cpp (source / functions) Hit Total Coverage
Test: coverage.info.cleaned Lines: 159 166 95.8 %
Date: 2021-02-18 20:07:22 Functions: 14 16 87.5 %

          Line data    Source code
       1             : /**
       2             :  * \file RemezBasedFilter.cpp
       3             :  */
       4             : 
       5             : #include "RemezBasedFilter.h"
       6             : #include <ATK/Core/Utilities.h>
       7             : #include <ATK/EQ/FIRFilter.h>
       8             : #include <ATK/Utility/FFT.h>
       9             : 
      10             : #include <boost/math/constants/constants.hpp>
      11             : 
      12             : #include <Eigen/Dense>
      13             : 
      14             : #include <vector>
      15             : 
      16             : namespace
      17             : {
      18             :   template<class DataType>
      19             :   class RemezBuilder
      20             :   {
      21             :   public:
      22             :     using AlignedScalarVector = typename ATK::TypedBaseFilter<DataType>::AlignedScalarVector;
      23             :   private:
      24             :     const static gsl::index grid_size = 1024; // grid size, power of two better for FFT
      25             :     constexpr static DataType SN = 1e-8;
      26             : 
      27             :     gsl::index M;
      28             :     std::vector<DataType> grid;
      29             :     std::vector<std::pair<std::pair<DataType, DataType>, std::pair<DataType, DataType>> > target;
      30             :     
      31             :     /// Computed coefficients
      32             :     AlignedScalarVector coeffs;
      33             :     /// Selected indices
      34             :     std::vector<gsl::index> indices;
      35             :     /// Weight function on the grid
      36             :     std::vector<DataType> weights;
      37             :     /// Objective function on the grid
      38             :     std::vector<DataType> objective;
      39             :     /// Alternate signs
      40             :     std::vector<int> s;
      41             : 
      42             :     ATK::FFT<DataType> fft_processor;
      43             : 
      44             :   public:
      45           3 :     RemezBuilder(gsl::index order, const std::vector<std::pair<std::pair<DataType, DataType>, std::pair<DataType, DataType>> >& target)
      46           3 :     :M(order / 2), target(target)
      47             :     {
      48           3 :       grid.resize(grid_size);
      49        3075 :       for(gsl::index i = 0; i < grid_size; ++i)
      50             :       {
      51        3072 :         grid[i] = i * boost::math::constants::pi<DataType>() / grid_size;
      52             :       }
      53           3 :       fft_processor.set_size(2 * grid_size);
      54           3 :     }
      55             :     
      56           1 :     void init()
      57             :     {
      58           1 :       coeffs.assign(M * 2 + 1, 0);
      59           1 :       indices.assign(M + 2, -1);
      60           1 :       s.assign(M+2, 0);
      61             :       
      62           1 :       weights.assign(grid_size, 0);
      63           1 :       objective.assign(grid_size, 0);
      64             : 
      65           1 :       int current_template = 0;
      66        1025 :       for(gsl::index i = 0; i < grid_size; ++i)
      67             :       {
      68        1024 :         auto reduced_freq = grid[i] / boost::math::constants::pi<DataType>();
      69        1024 :         if(reduced_freq > target[current_template].first.second && current_template + 1 < target.size())
      70             :         {
      71           1 :           ++current_template;
      72             :         }
      73        1024 :         if (reduced_freq < target[current_template].first.first || reduced_freq > target[current_template].first.second)
      74             :         {
      75         102 :           weights[i] = 0;
      76         102 :           objective[i] = 0;
      77             :         }
      78             :         else
      79             :         {
      80         922 :           weights[i] = target[current_template].second.second;
      81         922 :           objective[i] = target[current_template].second.first;
      82             :         }
      83             :       }
      84           1 :       int flag = -1;
      85           9 :       for (gsl::index i = 0; i < M + 2; ++i)
      86             :       {
      87           8 :         s[i] = flag;
      88           8 :         flag = -flag;
      89             :       }
      90           1 :       indices = set_starting_conditions();
      91           1 :     }
      92             : 
      93           1 :     std::vector<gsl::index> set_starting_conditions() const
      94             :     {
      95           1 :       std::vector<gsl::index> indices;
      96             : 
      97           2 :       std::vector<gsl::index> valid_indices;
      98        1025 :       for (gsl::index i = 0; i < grid_size; ++i)
      99             :       {
     100        1024 :         if (weights[i] != 0)
     101             :         {
     102         922 :           valid_indices.push_back(i);
     103             :         }
     104             :       }
     105             : 
     106           9 :       for (gsl::index i = 0; i < M + 2; ++i)
     107             :       {
     108           8 :         indices.push_back(valid_indices[std::lround(valid_indices.size() / (M + 4.) * (i + 1))]);
     109             :       }
     110             : 
     111           2 :       return indices;
     112             :     }
     113             :     
     114           3 :     AlignedScalarVector build()
     115             :     {
     116           3 :       if(target.empty())
     117             :       {
     118           2 :         coeffs.clear();
     119           2 :         return coeffs;
     120             :       }
     121           1 :       init();
     122           4 :       while(true)
     123             :       {
     124           5 :         Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic> A(M + 2, M + 2);
     125          45 :         for (gsl::index i = 0; i < M + 2; ++i)
     126             :         {
     127         320 :           for (gsl::index j = 0; j < M + 1; ++j)
     128             :           {
     129         280 :             A(i, j) = std::cos(grid[indices[i]] * j);
     130             :           }
     131             :         }
     132          45 :         for (gsl::index i = 0; i < M + 2; ++i)
     133             :         {
     134          40 :           A(i, M + 1) = s[i] / weights[indices[i]];
     135             :         }
     136             : 
     137           5 :         Eigen::Matrix<DataType, Eigen::Dynamic, 1> b(M + 2, 1);
     138          45 :         for (gsl::index i = 0; i < M + 2; ++i)
     139             :         {
     140          40 :           b(i) = objective[indices[i]];
     141             :         }
     142           5 :         Eigen::Matrix<DataType, Eigen::Dynamic, 1> x = A.colPivHouseholderQr().solve(b);
     143             : 
     144          35 :         for (gsl::index i = 0; i < M; ++i)
     145             :         {
     146          30 :           coeffs[i] = coeffs[2 * M - i] = x[M - i] / 2;
     147             :         }
     148           5 :         coeffs[M] = x[0];
     149           5 :         auto delta = std::abs(x[M + 1]); // maximum cost
     150             : 
     151           5 :         auto newerr = compute_new_error();
     152           5 :         auto new_indices = locmax(newerr);
     153             : 
     154           5 :         filter_SN(delta, new_indices, newerr);
     155           5 :         filter_monotony(new_indices, newerr);
     156             : 
     157           5 :         DataType max = 0;
     158          45 :         for (auto indice : new_indices)
     159             :         {
     160          40 :           max = std::max(max, std::abs(newerr[indice]));
     161             :         }
     162             : 
     163           5 :         if ((max - delta) / delta < SN)
     164             :         {
     165           1 :           break;
     166             :         }
     167             : 
     168           4 :         indices = std::move(new_indices);
     169             :       }
     170             : 
     171           1 :       return coeffs;
     172             :     }
     173             :     
     174             :   private:
     175             :     /// Creates a spectral response for a given set of coeffs through an FFT
     176           5 :     std::vector<DataType> firamp() const
     177             :     {
     178          10 :       std::vector<std::complex<DataType>> output(grid_size * 2);
     179           5 :       fft_processor.process_forward(coeffs.data(), output.data(), coeffs.size());
     180           5 :       std::vector<DataType> amp(grid_size);
     181        5125 :       for (gsl::index i = 0; i < grid_size; ++i)
     182             :       {
     183        5120 :         amp[i] = (std::complex<DataType>(std::cos(M * grid[i]), std::sin(M * grid[i])) * output[i]).real() * grid_size * 2;
     184             :       }
     185          10 :       return amp;
     186             :     }
     187             : 
     188           5 :     std::vector<DataType> compute_new_error() const
     189             :     {
     190          10 :       auto fir_result = firamp();
     191           5 :       std::vector<DataType> newerr(grid_size);
     192        5125 :       for (gsl::index i = 0; i < grid_size; ++i)
     193             :       {
     194        5120 :         newerr[i] = (fir_result[i] - objective[i]) * weights[i];
     195             :       }
     196             : 
     197          10 :       return newerr;
     198             :     }
     199             : 
     200             :     // Finds min and max
     201           5 :     std::vector<gsl::index> locmax(const std::vector<DataType>& data) const
     202             :     {
     203           5 :       std::vector<gsl::index> v;
     204             : 
     205          10 :       std::vector<DataType> temp1;
     206          10 :       std::vector<DataType> temp2;
     207           5 :       temp1.push_back(data[0] - 1);
     208        5120 :       for (gsl::index i = 0; i < data.size() - 1; ++i)
     209             :       {
     210        5115 :         temp1.push_back(data[i]);
     211        5115 :         temp2.push_back(data[i + 1]);
     212             :       }
     213           5 :       temp2.push_back(data.back() - 1);
     214             : 
     215        5125 :       for (gsl::index i = 0; i < data.size(); ++i)
     216             :       {
     217        5120 :         if ((data[i] > temp1[i]) && (data[i] > temp2[i]))
     218             :         {
     219          25 :           v.push_back(i);
     220             :         }
     221             :       }
     222             : 
     223           5 :       temp1.clear();
     224           5 :       temp2.clear();
     225           5 :       temp1.push_back(-data[0] - 1);
     226        5120 :       for (gsl::index i = 0; i < data.size() - 1; ++i)
     227             :       {
     228        5115 :         temp1.push_back(-data[i]);
     229        5115 :         temp2.push_back(-data[i + 1]);
     230             :       }
     231           5 :       temp2.push_back(-data.back() - 1);
     232             : 
     233        5125 :       for (gsl::index i = 0; i < data.size(); ++i)
     234             :       {
     235        5120 :         if ((-data[i] > temp1[i]) && (-data[i] > temp2[i]))
     236             :         {
     237          20 :           v.push_back(i);
     238             :         }
     239             :       }
     240             : 
     241           5 :       std::sort(v.begin(), v.end());
     242             : 
     243          10 :       return v;
     244             :     }
     245             : 
     246           5 :     void filter_SN(DataType delta, std::vector<gsl::index>& indices, const std::vector<DataType>& err) const
     247             :     {
     248          10 :       std::vector<gsl::index> new_indices;
     249             : 
     250          50 :       for (auto indice : indices)
     251             :       {
     252          45 :         if (std::abs(err[indice]) > (delta - SN))
     253             :         {
     254          41 :           new_indices.push_back(indice);
     255             :         }
     256             :       }
     257             : 
     258           5 :       indices = std::move(new_indices);
     259           5 :     }
     260             : 
     261           5 :     std::vector<gsl::index> etap(const std::vector<DataType>& data) const
     262             :     {
     263           5 :       std::vector<gsl::index> v;
     264           5 :       auto xe = data[0];
     265           5 :       gsl::index xv = 0;
     266          41 :       for (gsl::index i = 1; i < data.size(); ++i)
     267             :       {
     268          36 :         if (std::signbit(data[i]) == std::signbit(xe))
     269             :         {
     270           1 :           if (std::abs(data[i]) > std::abs(xe))
     271             :           {
     272           0 :             xe = data[i];
     273           0 :             xv = i;
     274             :           }
     275             :         }
     276             :         else
     277             :         {
     278          35 :           v.push_back(xv);
     279          35 :           xe = data[i];
     280          35 :           xv = i;
     281             :         }
     282             :       }
     283           5 :       v.push_back(xv);
     284          10 :       return v;
     285             :     }
     286             : 
     287           5 :     void filter_monotony(std::vector<gsl::index>& indices, const std::vector<DataType>& err) const
     288             :     {
     289          10 :       std::vector<DataType> filtered_err;
     290          46 :       for (auto indice : indices)
     291             :       {
     292          41 :         filtered_err.push_back(err[indice]);
     293             :       }
     294             : 
     295          10 :       auto selected_indices = etap(filtered_err);
     296          10 :       std::vector<gsl::index> new_indices;
     297          45 :       for (gsl::index i = 0; i < M + 2; ++i)
     298             :       {
     299          40 :         new_indices.push_back(indices[selected_indices[i]]);
     300             :       }
     301           5 :       indices = std::move(new_indices);
     302           5 :     }
     303             :   };
     304             : }
     305             : 
     306             : namespace ATK
     307             : {
     308             :   template<class DataType>
     309           3 :   RemezBasedCoefficients<DataType>::RemezBasedCoefficients(gsl::index nb_channels)
     310           3 :     :Parent(nb_channels, nb_channels)
     311             :   {
     312           3 :   }
     313             : 
     314             :   template<class DataType>
     315           0 :   RemezBasedCoefficients<DataType>::RemezBasedCoefficients(RemezBasedCoefficients&& other)
     316           0 :     :Parent(std::move(other)), target(std::move(other.target)), in_order(std::move(other.in_order)), coefficients_in(std::move(other.coefficients_in))
     317             :   {
     318           0 :   }
     319             : 
     320             :   template<class DataType>
     321           3 :   void RemezBasedCoefficients<DataType>::set_template(const std::vector<std::pair<std::pair<CoeffDataType, CoeffDataType>, std::pair<CoeffDataType, CoeffDataType> > >& target)
     322             :   {
     323           3 :     this->target = target;
     324           3 :     setup();
     325           2 :   }
     326             :   
     327             :   template<class DataType>
     328           0 :   const std::vector<std::pair<std::pair<typename RemezBasedCoefficients<DataType>::CoeffDataType, typename RemezBasedCoefficients<DataType>::CoeffDataType>, std::pair<typename RemezBasedCoefficients<DataType>::CoeffDataType, typename RemezBasedCoefficients<DataType>::CoeffDataType> > >& RemezBasedCoefficients<DataType>::get_template() const
     329             :   {
     330           0 :     return target;
     331             :   }
     332             :   
     333             :   template<class DataType>
     334           3 :   void RemezBasedCoefficients<DataType>::set_order(gsl::index order)
     335             :   {
     336           3 :     if(order % 2 == 1)
     337             :     {
     338           1 :       throw ATK::RuntimeError("Need an even filter order (considering order 0 has 1 coefficients)");
     339             :     }
     340           2 :     in_order = order;
     341           2 :     setup();
     342           2 :   }
     343             : 
     344             :   template<class DataType>
     345          11 :   void RemezBasedCoefficients<DataType>::setup()
     346             :   {
     347          11 :     Parent::setup();
     348             :     
     349          11 :     std::sort(target.begin(), target.end());
     350          13 :     for(gsl::index i = 0; i + 1 < target.size(); ++i)
     351             :     {
     352           3 :       if(target[i].first.second > target[i + 1].first.first)
     353             :       {
     354           1 :         target.clear();
     355           1 :         throw ATK::RuntimeError("Bad template");
     356             :       }
     357             :     }
     358             :     
     359          10 :     if (in_order > 0)
     360             :     {
     361           3 :       RemezBuilder<CoeffDataType> builder(in_order, target);
     362           3 :       coefficients_in = builder.build();
     363             :     }
     364          10 :   }
     365             :   
     366             :   template class ATK_EQ_EXPORT RemezBasedCoefficients<double>;
     367             : #if ATK_ENABLE_INSTANTIATION
     368             :   template class ATK_EQ_EXPORT RemezBasedCoefficients<std::complex<double> >;
     369             : #endif
     370             : }

Generated by: LCOV version TK-3.3.0-4-gdba42eea