LCOV - code coverage report
Current view: top level - Utility - VectorizedNewtonRaphson.h (source / functions) Hit Total Coverage
Test: coverage.info.cleaned Lines: 25 26 96.2 %
Date: 2021-02-18 20:07:22 Functions: 32 56 57.1 %

          Line data    Source code
       1             : /**
       2             :  * \file VectorizedNewtonRaphson.h
       3             :  */
       4             : 
       5             : #ifndef ATK_UTILITY_VECTORIZEDNEWTONRAPHSON_H
       6             : #define ATK_UTILITY_VECTORIZEDNEWTONRAPHSON_H
       7             : 
       8             : #include <cmath>
       9             : #if ATK_PROFILING == 1
      10             : #include <iostream>
      11             : #endif
      12             : #include <limits>
      13             : 
      14             : #include <ATK/config.h>
      15             : 
      16             : #include <Eigen/Dense>
      17             : 
      18             : namespace ATK
      19             : {
      20             :   /// Vectorized Newton Raphson optimizer
      21             :   /*!
      22             :    * A NR optimizer, 10 iterations max
      23             :    */
      24             :   template<typename Function, gsl::index size, gsl::index max_iterations=10, bool check_convergence=true>
      25             :   class VectorizedNewtonRaphson
      26             :   {
      27             :     using DataType = typename Function::DataType;
      28             :     
      29             :     Function function;
      30             :     
      31             :     DataType precision;
      32             :     
      33             : #if ATK_PROFILING == 1
      34             :     int64_t nb_iterations{0};
      35             :     int64_t nb_optimizations{0};
      36             : #endif
      37             :     using Vector = Eigen::Matrix<DataType, size, 1>;
      38             :     
      39             :   public:
      40             :     /*!
      41             :      * @brief Constructs the optimizer
      42             :      * @param function is the function that we will try to optimize.
      43             :      *   It is a functor taking x[n-1], x[n], y[n-1] and an estimate y[n], returning the value of the cost function and its derivative according to y[n]
      44             :      * @param precision is the precision that the optimizer will try to achieve. By default uses $$\\sqrt{\\epsilon_{Datatype}}$$
      45             :      */
      46          12 :     VectorizedNewtonRaphson(Function&& function, DataType precision = 0)
      47          12 :     :function(std::move(function)), precision(precision)
      48             :     {
      49          12 :       if(precision == 0)
      50             :       {
      51          12 :         this->precision = std::sqrt(std::numeric_limits<DataType>::epsilon());
      52             :       }
      53          12 :     }
      54             :     
      55             : #if ATK_PROFILING == 1
      56             :     ~VectorizedNewtonRaphson()
      57             :     {
      58             :       std::cout << "nb optimizations: " << nb_optimizations << std::endl;
      59             :       std::cout << "nb iterations: " << nb_iterations << std::endl;
      60             :       std::cout << "average: " << nb_iterations / double(nb_optimizations) << std::endl;
      61             :     }
      62             : #endif
      63             :     
      64             :     VectorizedNewtonRaphson(const VectorizedNewtonRaphson&) = delete;
      65             :     VectorizedNewtonRaphson& operator=(const VectorizedNewtonRaphson&) = delete;
      66             : 
      67             :     /// Optimize the function and sets its internal state
      68       74400 :     void optimize(gsl::index i, const DataType* const * ATK_RESTRICT input, DataType* const * ATK_RESTRICT output)
      69             :     {
      70             : #if ATK_PROFILING == 1
      71             :       ++nb_optimizations;
      72             : #endif
      73       74400 :       auto res = optimize_impl(i, input, output);
      74      376800 :       for(gsl::index j = 0; j < size; ++j)
      75             :       {
      76      302400 :         output[j][i] = res.data()[j];
      77             :       }
      78       74400 :     }
      79             : 
      80             :     /// Returns the function
      81       74400 :     Function& get_function()
      82             :     {
      83       74400 :       return function;
      84             :     }
      85             : 
      86             :     /// Returns the function
      87             :     const Function& get_function() const
      88             :     {
      89             :       return function;
      90             :     }
      91             : 
      92             :   protected:
      93             :     /// Just optimize the function
      94       74400 :     Vector optimize_impl(gsl::index i, const DataType* const * ATK_RESTRICT input, DataType* const * ATK_RESTRICT output)
      95             :     {
      96       74400 :       Vector y1 = function.estimate(i, input, output);
      97             :       
      98             :       gsl::index j;
      99      172125 :       for(j = 0; j < max_iterations; ++j)
     100             :       {
     101             : #if ATK_PROFILING == 1
     102             :         ++nb_iterations;
     103             : #endif
     104      172123 :         auto cx = function(i, input, output, y1);
     105      172123 :         auto yk = y1 - cx;
     106      172123 :         if((cx.array().abs() < precision).all())
     107             :         {
     108       74398 :           return yk;
     109             :         }
     110       97725 :         y1 = yk;
     111             :       }
     112           2 :       if(check_convergence && j == max_iterations)
     113             :       {
     114           2 :         Vector y0;
     115          10 :         for(gsl::index j = 0; j < size; ++j)
     116             :         {
     117           8 :           y0.data()[j] = output[j][i-1];
     118             :         }
     119             : 
     120           2 :         return y0; // Stay the same
     121             :       }
     122           0 :       return y1;
     123             :     }
     124             :   };
     125             : }
     126             : 
     127             : #endif

Generated by: LCOV version TK-3.3.0-4-gdba42eea