[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_earlystopping.hxx
00001 #ifndef RF_EARLY_STOPPING_P_HXX
00002 #define RF_EARLY_STOPPING_P_HXX
00003 #include <cmath>
00004 #include "rf_common.hxx"
00005 
00006 namespace vigra
00007 {
00008 
#if 0
// Disabled helper (compiled out): the binomial() members in this file call
// std::pow instead.
namespace es_detail
{
    /** Raise \a in to the integer power \a n by repeated multiplication.
     *
     *  Starts from NumericTraits<T>::one(); for n <= 0 no multiplication
     *  is performed and that neutral element is returned unchanged.
     */
    template<class T>
    T power(T const & in, int n)
    {
        T product = NumericTraits<T>::one();
        int remaining = n;
        while(remaining > 0)
        {
            product *= in;
            --remaining;
        }
        return product;
    }
}
#endif
00022 
00023 /**Base class from which all EarlyStopping Functors derive.
00024  */
00025 class StopBase
00026 {
00027 protected:
00028     ProblemSpec<> ext_param_;
00029     int tree_count_ ;
00030     bool is_weighted_;
00031 
00032 public:
00033     template<class T>
00034     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00035     {
00036         ext_param_ = prob; 
00037         is_weighted_ = is_weighted;
00038         tree_count_ = tree_count;
00039     }
00040     
00041     /** called after the prediction of a tree was added to the total prediction
00042      * \param WeightIter Iterator to the weights delivered by current tree.
00043      * \param k          after kth tree
00044      * \param prob       Total probability array
00045      * \param totalCt    sum of probability array. 
00046      */
00047     template<class WeightIter, class T, class C>
00048     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
00049     {return false;}
00050 };
00051 
00052 
00053 /**Stop predicting after a set number of trees
00054  */
00055 class StopAfterTree : public StopBase
00056 {
00057 public:
00058     double max_tree_p;
00059     int max_tree_;
00060     typedef StopBase SB;
00061     
00062     ArrayVector<double> depths;
00063     
00064     /** Constructor
00065      * \param max_tree number of trees to be used for prediction
00066      */
00067     StopAfterTree(double max_tree)
00068     :
00069         max_tree_p(max_tree)
00070     {}
00071 
00072     template<class T>
00073     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00074     {
00075         max_tree_ = ceil(max_tree_p * tree_count);
00076         SB::set_external_parameters(prob, tree_count, is_weighted);
00077     }
00078 
00079     template<class WeightIter, class T, class C>
00080     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
00081     {
00082         if(k == SB::tree_count_ -1)
00083         {
00084                 depths.push_back(double(k+1)/double(SB::tree_count_));
00085                 return false;
00086         }
00087         if(k < max_tree_)
00088            return false;
00089         depths.push_back(double(k+1)/double(SB::tree_count_));
00090         return true;  
00091     }
00092 };
00093 
00094 /** Stop predicting after a certain amount of votes exceed certain proportion.
00095  *  case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_ 
00096  *  case weighted votion: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
00097  *                          (maximal number of votes possible in both cases)
00098  */
00099 class StopAfterVoteCount : public StopBase
00100 {
00101 public:
00102     double proportion_;
00103     typedef StopBase SB;
00104     ArrayVector<double> depths;
00105 
00106     /** Constructor
00107      * \param proportion specify proportion to be used.
00108      */
00109     StopAfterVoteCount(double proportion)
00110     :
00111         proportion_(proportion)
00112     {}
00113 
00114     template<class WeightIter, class T, class C>
00115     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
00116     {
00117         if(k == SB::tree_count_ -1)
00118         {
00119                 depths.push_back(double(k+1)/double(SB::tree_count_));
00120                 return false;
00121         }
00122 
00123 
00124         if(SB::is_weighted_)
00125         {
00126             if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
00127             {
00128                 depths.push_back(double(k+1)/double(SB::tree_count_));
00129                 return true;
00130             }
00131         }
00132         else
00133         {
00134             if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
00135             {
00136                 depths.push_back(double(k+1)/double(SB::tree_count_));
00137                 return true;
00138             }
00139         }
00140         return false;
00141     }
00142 
00143 };
00144 
00145 
00146 /** Stop predicting if the 2norm of the probabilities does not change*/
00147 class StopIfConverging : public StopBase
00148 
00149 {
00150 public:
00151     double thresh_;
00152     int num_;
00153     MultiArray<2, double> last_;
00154     MultiArray<2, double> cur_;
00155     ArrayVector<double> depths;
00156     typedef StopBase SB;
00157 
00158     /** Constructor
00159      * \param thresh: If the two norm of the probabilites changes less then thresh then stop
00160      * \param num   : look at atleast num trees before stopping
00161      */
00162     StopIfConverging(double thresh, int num = 10)
00163     :
00164         thresh_(thresh), 
00165         num_(num)
00166     {}
00167 
00168     template<class T>
00169     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00170     {
00171         last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
00172         cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
00173         SB::set_external_parameters(prob, tree_count, is_weighted);
00174     }
00175     template<class WeightIter, class T, class C>
00176     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> const & prob, double totalCt)
00177     {
00178         if(k == SB::tree_count_ -1)
00179         {
00180                 depths.push_back(double(k+1)/double(SB::tree_count_));
00181                 return false;
00182         }
00183         if(k <= num_)
00184         {
00185             last_ = prob;
00186             last_/= last_.norm(1);
00187             return false;
00188         }
00189         else 
00190         {
00191             cur_ = prob;
00192             cur_ /= cur_.norm(1);
00193             last_ -= cur_;
00194             double nrm = last_.norm(); 
00195             if(nrm < thresh_)
00196             {
00197                 depths.push_back(double(k+1)/double(SB::tree_count_));
00198                 return true;
00199             }
00200             else
00201             {
00202                 last_ = cur_;
00203             }
00204         }
00205         return false;
00206     }
00207 };
00208 
00209 
00210 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
00211  *  case unweighted voting: stop if margin exceeds proportion * SB::tree_count_ 
00212  *  case weighted votion: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
00213  *                          (maximal number of votes possible in both cases)
00214  */
00215 class StopIfMargin : public StopBase  
00216 {
00217 public:
00218     double proportion_;
00219     typedef StopBase SB;
00220     ArrayVector<double> depths;
00221 
00222     /** Constructor
00223      * \param proportion specify proportion to be used.
00224      */
00225     StopIfMargin(double proportion)
00226     :
00227         proportion_(proportion)
00228     {}
00229 
00230     template<class WeightIter, class T, class C>
00231     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
00232     {
00233         if(k == SB::tree_count_ -1)
00234         {
00235                 depths.push_back(double(k+1)/double(SB::tree_count_));
00236                 return false;
00237         }
00238         int index = argMax(prob);
00239         double a = prob[argMax(prob)];
00240         prob[argMax(prob)] = 0;
00241         double b = prob[argMax(prob)];
00242         prob[index] = a; 
00243         double margin = a - b;
00244         if(SB::is_weighted_)
00245         {
00246             if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
00247             {
00248                 depths.push_back(double(k+1)/double(SB::tree_count_));
00249                 return true;
00250             }
00251         }
00252         else
00253         {
00254             if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
00255             {
00256                 depths.push_back(double(k+1)/double(SB::tree_count_));
00257                 return true;
00258             }
00259         }
00260         return false;
00261     }
00262 };
00263 
00264 
00265 /**Probabilistic Stopping criterion (binomial test)
00266  *
00267  * Can only be used in a two class setting
00268  *
00269  * Stop if the Parameters estimated for the underlying binomial distribution
00270  * can be estimated with certainty over 1-alpha.
00271  * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion
00272  */
00273 class StopIfBinTest : public StopBase  
00274 {
00275 public:
00276     double alpha_;  
00277     MultiArrayView<2, double> n_choose_k;
00278     /** Constructor
00279      * \param proportion specify alpha value for binomial test.
00280      * \param nck_ Matrix with precomputed values for n choose k
00281      * nck_(n, k) is n choose k. 
00282      */
00283     StopIfBinTest(double alpha, MultiArrayView<2, double> nck_)
00284     :
00285         alpha_(alpha),
00286         n_choose_k(nck_)
00287     {}
00288     typedef StopBase SB;
00289     
00290     /**ArrayVector that will contain the fraction of trees that was visited before terminating
00291      */
00292     ArrayVector<double> depths;
00293 
00294     double binomial(int N, int k, double p)
00295     {
00296 //        return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
00297         return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
00298     }
00299 
00300     template<class WeightIter, class T, class C>
00301     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> prob, double totalCt)
00302     {
00303         if(k == SB::tree_count_ -1)
00304         {
00305                 depths.push_back(double(k+1)/double(SB::tree_count_));
00306                 return false;
00307         }
00308         if(k < 10)
00309         {
00310             return false;
00311         }
00312         int index = argMax(prob);
00313         int n_a  = prob[index];
00314         int n_b  = prob[(index+1)%2];
00315         int n_tilde = (SB::tree_count_ - n_a + n_b);
00316         double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
00317         vigra_precondition(p_a <= 1, "probability should be smaller than 1");
00318         double cum_val = 0;
00319         int c = 0; 
00320   //      std::cerr << "prob: " << p_a << std::endl;
00321         if(n_a <= 0)n_a = 0;
00322         if(n_b <= 0)n_b = 0;
00323         for(int ii = 0; ii <= n_b + n_a;++ii)
00324         {
00325 //            std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
00326             cum_val += binomial(n_b + n_a, ii, p_a); 
00327             if(cum_val >= 1 -alpha_)
00328             {
00329                 c = ii;
00330                 break;
00331             }
00332         }
00333 //        std::cerr << c << " " << n_a << " " << n_b << " " << p_a <<   alpha_ << std::endl;
00334         if(c < n_a)
00335         {
00336             depths.push_back(double(k+1)/double(SB::tree_count_));
00337             return true;
00338         }
00339 
00340         return false;
00341     }
00342 };
00343 
00344 /**Probabilistic Stopping criteria. (toChange)
00345  *
00346  * Can only be used in a two class setting
00347  *
00348  * Stop if the probability that the decision will change after seeing all trees falls under
00349  * a specified value alpha.
00350  * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion
00351  */
00352 class StopIfProb : public StopBase  
00353 {
00354 public:
00355     double alpha_;  
00356     MultiArrayView<2, double> n_choose_k;
00357     
00358     
00359     /** Constructor
00360      * \param proportion specify alpha value
00361      * \param nck_ Matrix with precomputed values for n choose k
00362      * nck_(n, k) is n choose k. 
00363      */
00364     StopIfProb(double alpha, MultiArrayView<2, double> nck_)
00365     :
00366         alpha_(alpha),
00367         n_choose_k(nck_)
00368     {}
00369     typedef StopBase SB;
00370     /**ArrayVector that will contain the fraction of trees that was visited before terminating
00371      */
00372     ArrayVector<double> depths;
00373 
00374     double binomial(int N, int k, double p)
00375     {
00376 //        return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
00377         return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
00378     }
00379 
00380     template<class WeightIter, class T, class C>
00381     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> prob, double totalCt)
00382     {
00383         if(k == SB::tree_count_ -1)
00384         {
00385                 depths.push_back(double(k+1)/double(SB::tree_count_));
00386                 return false;
00387         }
00388         if(k <= 10)
00389         {
00390             return false;
00391         }
00392         int index = argMax(prob);
00393         int n_a  = prob[index];
00394         int n_b  = prob[(index+1)%2];
00395         int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
00396         int n_tilde = SB::tree_count_ - (n_a +n_b);
00397         if(n_tilde <= 0) n_tilde = 0;
00398         if(n_needed <= 0) n_needed = 0;
00399         double p = 0;
00400         for(int ii = n_needed; ii < n_tilde; ++ii)
00401             p += binomial(n_tilde, ii, 0.5);
00402         
00403         if(p >= 1-alpha_)
00404         {
00405             depths.push_back(double(k+1)/double(SB::tree_count_));
00406             return true;
00407         }
00408 
00409         return false;
00410     }
00411 };
00412 } //namespace vigra;
00413 #endif //RF_EARLY_STOPPING_P_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Sun Feb 19 2012)