Line data Source code
1 : /** 2 : * \file ScalarNewtonRaphson.h 3 : */ 4 : 5 : #ifndef ATK_UTILITY_SCALAR_NEWTONRAPHSON_H 6 : #define ATK_UTILITY_SCALAR_NEWTONRAPHSON_H 7 : 8 : #include <ATK/config.h> 9 : 10 : #include <cmath> 11 : #if ATK_PROFILING == 1 12 : #include <iostream> 13 : #endif 14 : #include <limits> 15 : 16 : namespace ATK 17 : { 18 : /// Scalar Newton Raphson optimizer 19 : /*! 20 : * A NR optimizer, 10 iterations max 21 : */ 22 : template<typename Function, int max_iterations=10, bool check_convergence=true> 23 : class ScalarNewtonRaphson 24 : { 25 : using DataType = typename Function::DataType; 26 : 27 : Function function; 28 : 29 : DataType precision; 30 : DataType maxstep{1}; 31 : 32 : #if ATK_PROFILING == 1 33 : int64_t nb_iterations{0}; 34 : int64_t nb_optimizations{0}; 35 : #endif 36 : 37 : public: 38 : /*! 39 : * @brief Constructs the optimizer 40 : * @param function is the function that we will try to optimize. 41 : * It is a functor taking x[n-1], x[n], y[n-1] and an estimate y[n], returning the value of the cost function and its derivative according to y[n] 42 : * @param precision is the precision that the optimizer will try to achieve. By default uses $$\\sqrt{\\epsilon_{Datatype}}$$ 43 : */ 44 6 : ScalarNewtonRaphson(Function&& function, DataType precision = 0) 45 6 : :function(std::move(function)), precision(precision) 46 : { 47 6 : if(precision == 0) 48 : { 49 6 : this->precision = std::sqrt(std::numeric_limits<DataType>::epsilon()); 50 : } 51 6 : } 52 : 53 : #if ATK_PROFILING == 1 54 : ~ScalarNewtonRaphson() 55 : { 56 : std::cout << "nb optimizations: " << nb_optimizations << std::endl; 57 : std::cout << "nb iterations: " << nb_iterations << std::endl; 58 : std::cout << "average: " << nb_iterations / double(nb_optimizations) << std::endl; 59 : } 60 : #endif 61 : 62 : ScalarNewtonRaphson(const ScalarNewtonRaphson&) = delete; 63 : ScalarNewtonRaphson& operator=(const ScalarNewtonRaphson&) = delete; 64 : 65 : /// Optimize the function and sets its internal state 66 17000 : void optimize(const DataType* ATK_RESTRICT input, DataType* ATK_RESTRICT output) 67 : { 68 : #if ATK_PROFILING == 1 69 : ++nb_optimizations; 70 : #endif 71 17000 : output[0] = optimize_impl(input, output); 72 17000 : } 73 : 74 : /// Returns the function 75 8004 : Function& get_function() 76 : { 77 8004 : return function; 78 : } 79 : 80 : /// Returns the function 81 : const Function& get_function() const 82 : { 83 : return function; 84 : } 85 : 86 : protected: 87 : /// Just optimize the function 88 17000 : DataType optimize_impl(const DataType* ATK_RESTRICT input, DataType* ATK_RESTRICT output) 89 : { 90 17000 : DataType y1 = function.estimate(input, output); 91 : gsl::index i; 92 : 93 29737 : for(i = 0; i < max_iterations; ++i) 94 : { 95 : #if ATK_PROFILING == 1 96 : ++nb_iterations; 97 : #endif 98 29737 : std::pair<DataType, DataType> all = function(input, output, y1); 99 29737 : if(std::abs(all.second) < std::numeric_limits<DataType>::epsilon() ) 100 : { 101 17000 : return y1; 102 : } 103 29737 : DataType cx = all.first / all.second; 104 29737 : if(cx < -maxstep) 105 : { 106 0 : cx = -maxstep; 107 : } 108 29737 : if(cx > maxstep) 109 : { 110 0 : cx = maxstep; 111 : } 112 29737 : DataType yk = y1 - cx; 113 29737 : if( std::abs(yk - y1) < precision ) 114 : { 115 17000 : return yk; 116 : } 117 12737 : y1 = yk; 118 : } 119 0 : if(check_convergence && i == max_iterations) 120 : { 121 0 : return output[-1]; // Stay the same 122 : } 123 0 : return y1; 124 : } 125 : }; 126 : } 127 : 128 : #endif