LCOV - code coverage report
Current view: top level - Utility - ScalarNewtonRaphson.h (source / functions) Hit Total Coverage
Test: coverage.info.cleaned Lines: 23 28 82.1 %
Date: 2021-02-18 20:07:22 Functions: 14 14 100.0 %

          Line data    Source code
       1             : /**
       2             :  * \file ScalarNewtonRaphson.h
       3             :  */
       4             : 
       5             : #ifndef ATK_UTILITY_SCALAR_NEWTONRAPHSON_H
       6             : #define ATK_UTILITY_SCALAR_NEWTONRAPHSON_H
       7             : 
       8             : #include <ATK/config.h>
       9             : 
      10             : #include <cmath>
      11             : #if ATK_PROFILING == 1
      12             : #include <iostream>
      13             : #endif
      14             : #include <limits>
      15             : 
      16             : namespace ATK
      17             : {
      18             :   /// Scalar Newton Raphson optimizer
      19             :   /*!
      20             :    * A NR optimizer, 10 iterations max
      21             :    */
      22             :   template<typename Function, int max_iterations=10, bool check_convergence=true>
      23             :   class ScalarNewtonRaphson
      24             :   {
      25             :     using DataType = typename Function::DataType;
      26             :     
      27             :     Function function;
      28             :     
      29             :     DataType precision;
      30             :     DataType maxstep{1};
      31             :     
      32             : #if ATK_PROFILING == 1
      33             :     int64_t nb_iterations{0};
      34             :     int64_t nb_optimizations{0};
      35             : #endif
      36             :     
      37             :   public:
      38             :     /*!
      39             :      * @brief Constructs the optimizer
      40             :      * @param function is the function that we will try to optimize.
      41             :      *   It is a functor taking x[n-1], x[n], y[n-1] and an estimate y[n], returning the value of the cost function and its derivative according to y[n]
      42             :      * @param precision is the precision that the optimizer will try to achieve. By default uses $$\\sqrt{\\epsilon_{Datatype}}$$
      43             :      */
      44           6 :     ScalarNewtonRaphson(Function&& function, DataType precision = 0)
      45           6 :     :function(std::move(function)), precision(precision)
      46             :     {
      47           6 :       if(precision == 0)
      48             :       {
      49           6 :         this->precision = std::sqrt(std::numeric_limits<DataType>::epsilon());
      50             :       }
      51           6 :     }
      52             : 
      53             : #if ATK_PROFILING == 1
      54             :     ~ScalarNewtonRaphson()
      55             :     {
      56             :       std::cout << "nb optimizations: " << nb_optimizations << std::endl;
      57             :       std::cout << "nb iterations: " << nb_iterations << std::endl;
      58             :       std::cout << "average: " << nb_iterations / double(nb_optimizations) << std::endl;
      59             :     }
      60             : #endif
      61             :     
      62             :     ScalarNewtonRaphson(const ScalarNewtonRaphson&) = delete;
      63             :     ScalarNewtonRaphson& operator=(const ScalarNewtonRaphson&) = delete;
      64             : 
      65             :     /// Optimize the function and sets its internal state
      66       17000 :     void optimize(const DataType* ATK_RESTRICT input, DataType* ATK_RESTRICT output)
      67             :     {
      68             : #if ATK_PROFILING == 1
      69             :       ++nb_optimizations;
      70             : #endif
      71       17000 :       output[0] = optimize_impl(input, output);
      72       17000 :     }
      73             : 
      74             :     /// Returns the function
      75        8004 :     Function& get_function()
      76             :     {
      77        8004 :       return function;
      78             :     }
      79             : 
      80             :     /// Returns the function
      81             :     const Function& get_function() const
      82             :     {
      83             :       return function;
      84             :     }
      85             : 
      86             :   protected:
      87             :     /// Just optimize the function
      88       17000 :     DataType optimize_impl(const DataType* ATK_RESTRICT input, DataType* ATK_RESTRICT output)
      89             :     {
      90       17000 :       DataType y1 = function.estimate(input, output);
      91             :       gsl::index i;
      92             :       
      93       29737 :       for(i = 0; i < max_iterations; ++i)
      94             :       {
      95             : #if ATK_PROFILING == 1
      96             :         ++nb_iterations;
      97             : #endif
      98       29737 :         std::pair<DataType, DataType> all = function(input, output, y1);
      99       29737 :         if(std::abs(all.second) < std::numeric_limits<DataType>::epsilon() )
     100             :         {
     101       17000 :           return y1;
     102             :         }
     103       29737 :         DataType cx = all.first / all.second;
     104       29737 :         if(cx < -maxstep)
     105             :         {
     106           0 :           cx = -maxstep;
     107             :         }
     108       29737 :         if(cx > maxstep)
     109             :         {
     110           0 :           cx = maxstep;
     111             :         }
     112       29737 :         DataType yk = y1 - cx;
     113       29737 :         if( std::abs(yk - y1) < precision )
     114             :         {
     115       17000 :           return yk;
     116             :         }
     117       12737 :         y1 = yk;
     118             :       }
     119           0 :       if(check_convergence && i == max_iterations)
     120             :       {
     121           0 :         return output[-1]; // Stay the same
     122             :       }
     123           0 :       return y1;
     124             :     }
     125             :   };
     126             : }
     127             : 
     128             : #endif

Generated by: LCOV version TK-3.3.0-4-gdba42eea