[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_earlystopping.hxx
00001 #ifndef RF_EARLY_STOPPING_P_HXX
00002 #define RF_EARLY_STOPPING_P_HXX
00003 #include <cmath>
00004 #include "rf_common.hxx"
00005 
00006 namespace vigra
00007 {
00008 
00009 #if 0    
00010 namespace es_detail
00011 {
00012     template<class T>
00013     T power(T const & in, int n)
00014     {
00015         T result = NumericTraits<T>::one();
00016         for(int ii = 0; ii < n ;++ii)
00017             result *= in;
00018         return result;
00019     }
00020 }
00021 #endif
00022 class StopBase
00023 {
00024 protected:
00025     ProblemSpec<> ext_param_;
00026     int tree_count_ ;
00027     bool is_weighted_;
00028 
00029 public:
00030     template<class T>
00031     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00032     {
00033         ext_param_ = prob; 
00034         is_weighted_ = is_weighted;
00035         tree_count_ = tree_count;
00036     }
00037 };
00038 
00039 
00040 /**Stop predicting after a set number of trees
00041  */
00042 class StopAfterTree : public StopBase
00043 {
00044 public:
00045     double max_tree_p;
00046     int max_tree_;
00047     typedef StopBase SB;
00048     
00049     ArrayVector<double> depths;
00050     StopAfterTree(double max_tree)
00051     :
00052         max_tree_p(max_tree)
00053     {}
00054 
00055     template<class T>
00056     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00057     {
00058         max_tree_ = ceil(max_tree_p * tree_count);
00059         SB::set_external_parameters(prob, tree_count, is_weighted);
00060     }
00061 
00062     template<class WeightIter, class T, class C>
00063     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
00064     {
00065         if(k == SB::tree_count_ -1)
00066         {
00067                 depths.push_back(double(k+1)/double(SB::tree_count_));
00068                 return false;
00069         }
00070         if(k < max_tree_)
00071            return false;
00072         depths.push_back(double(k+1)/double(SB::tree_count_));
00073         return true;  
00074     }
00075 };
00076 
00077 /** Stop predicting after a certain amount of votes exceed certain proportion.
00078  *  case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_ 
00079  *  case weighted votion: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
00080  *                          (maximal number of votes possible in both cases)
00081  */
00082 class StopAfterVoteCount : public StopBase
00083 {
00084 public:
00085     double proportion_;
00086     typedef StopBase SB;
00087     ArrayVector<double> depths;
00088     StopAfterVoteCount(double proportion)
00089     :
00090         proportion_(proportion)
00091     {}
00092 
00093     template<class WeightIter, class T, class C>
00094     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
00095     {
00096         if(k == SB::tree_count_ -1)
00097         {
00098                 depths.push_back(double(k+1)/double(SB::tree_count_));
00099                 return false;
00100         }
00101 
00102 
00103         if(SB::is_weighted_)
00104         {
00105             if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
00106             {
00107                 depths.push_back(double(k+1)/double(SB::tree_count_));
00108                 return true;
00109             }
00110         }
00111         else
00112         {
00113             if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
00114             {
00115                 depths.push_back(double(k+1)/double(SB::tree_count_));
00116                 return true;
00117             }
00118         }
00119         return false;
00120     }
00121 
00122 };
00123 
00124 
00125 /** Stop predicting if the 2norm of the probabilities does not change*/
00126 class StopIfConverging : public StopBase
00127 
00128 {
00129 public:
00130     double thresh_;
00131     int num_;
00132     MultiArray<2, double> last_;
00133     MultiArray<2, double> cur_;
00134     ArrayVector<double> depths;
00135     typedef StopBase SB;
00136 
00137     StopIfConverging(double thresh, int num = 10)
00138     :
00139         thresh_(thresh), 
00140         num_(num)
00141     {}
00142 
00143     template<class T>
00144     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00145     {
00146         last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
00147         cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
00148         SB::set_external_parameters(prob, tree_count, is_weighted);
00149     }
00150     template<class WeightIter, class T, class C>
00151     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> const & prob, double totalCt)
00152     {
00153         if(k == SB::tree_count_ -1)
00154         {
00155                 depths.push_back(double(k+1)/double(SB::tree_count_));
00156                 return false;
00157         }
00158         if(k <= num_)
00159         {
00160             last_ = prob;
00161             last_/= last_.norm(1);
00162             return false;
00163         }
00164         else 
00165         {
00166             cur_ = prob;
00167             cur_ /= cur_.norm(1);
00168             last_ -= cur_;
00169             double nrm = last_.norm(); 
00170             if(nrm < thresh_)
00171             {
00172                 depths.push_back(double(k+1)/double(SB::tree_count_));
00173                 return true;
00174             }
00175             else
00176             {
00177                 last_ = cur_;
00178             }
00179         }
00180         return false;
00181     }
00182 };
00183 
00184 
00185 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
00186  *  case unweighted voting: stop if margin exceeds proportion * SB::tree_count_ 
00187  *  case weighted votion: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
00188  *                          (maximal number of votes possible in both cases)
00189  */
00190 class StopIfMargin : public StopBase  
00191 {
00192 public:
00193     double proportion_;
00194     typedef StopBase SB;
00195     ArrayVector<double> depths;
00196 
00197     StopIfMargin(double proportion)
00198     :
00199         proportion_(proportion)
00200     {}
00201 
00202     template<class WeightIter, class T, class C>
00203     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
00204     {
00205         if(k == SB::tree_count_ -1)
00206         {
00207                 depths.push_back(double(k+1)/double(SB::tree_count_));
00208                 return false;
00209         }
00210         int index = argMax(prob);
00211         double a = prob[argMax(prob)];
00212         prob[argMax(prob)] = 0;
00213         double b = prob[argMax(prob)];
00214         prob[index] = a; 
00215         double margin = a - b;
00216         if(SB::is_weighted_)
00217         {
00218             if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
00219             {
00220                 depths.push_back(double(k+1)/double(SB::tree_count_));
00221                 return true;
00222             }
00223         }
00224         else
00225         {
00226             if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
00227             {
00228                 depths.push_back(double(k+1)/double(SB::tree_count_));
00229                 return true;
00230             }
00231         }
00232         return false;
00233     }
00234 };
00235 
00236 class StopIfBinTest : public StopBase  
00237 {
00238 public:
00239     double alpha_;  
00240     MultiArrayView<2, double> n_choose_k;
00241     StopIfBinTest(double alpha, MultiArrayView<2, double> nck_)
00242     :
00243         alpha_(alpha),
00244         n_choose_k(nck_)
00245     {}
00246     typedef StopBase SB;
00247     ArrayVector<double> depths;
00248 
00249     double binomial(int N, int k, double p)
00250     {
00251 //        return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
00252         return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
00253     }
00254 
00255     template<class WeightIter, class T, class C>
00256     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> prob, double totalCt)
00257     {
00258         if(k == SB::tree_count_ -1)
00259         {
00260                 depths.push_back(double(k+1)/double(SB::tree_count_));
00261                 return false;
00262         }
00263         if(k < 10)
00264         {
00265             return false;
00266         }
00267         int index = argMax(prob);
00268         int n_a  = prob[index];
00269         int n_b  = prob[(index+1)%2];
00270         int n_tilde = (SB::tree_count_ - n_a + n_b);
00271         double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
00272         vigra_precondition(p_a <= 1, "probability should be smaller than 1");
00273         double cum_val = 0;
00274         int c = 0; 
00275   //      std::cerr << "prob: " << p_a << std::endl;
00276         if(n_a <= 0)n_a = 0;
00277         if(n_b <= 0)n_b = 0;
00278         for(int ii = 0; ii <= n_b + n_a;++ii)
00279         {
00280 //            std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
00281             cum_val += binomial(n_b + n_a, ii, p_a); 
00282             if(cum_val >= 1 -alpha_)
00283             {
00284                 c = ii;
00285                 break;
00286             }
00287         }
00288 //        std::cerr << c << " " << n_a << " " << n_b << " " << p_a <<   alpha_ << std::endl;
00289         if(c < n_a)
00290         {
00291             depths.push_back(double(k+1)/double(SB::tree_count_));
00292             return true;
00293         }
00294 
00295         return false;
00296     }
00297 };
00298 
00299 
00300 class StopIfProb : public StopBase  
00301 {
00302 public:
00303     double alpha_;  
00304     MultiArrayView<2, double> n_choose_k;
00305     StopIfProb(double alpha, MultiArrayView<2, double> nck_)
00306     :
00307         alpha_(alpha),
00308         n_choose_k(nck_)
00309     {}
00310     typedef StopBase SB;
00311     ArrayVector<double> depths;
00312 
00313     double binomial(int N, int k, double p)
00314     {
00315 //        return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
00316         return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
00317     }
00318 
00319     template<class WeightIter, class T, class C>
00320     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> prob, double totalCt)
00321     {
00322         if(k == SB::tree_count_ -1)
00323         {
00324                 depths.push_back(double(k+1)/double(SB::tree_count_));
00325                 return false;
00326         }
00327         if(k <= 10)
00328         {
00329             return false;
00330         }
00331         int index = argMax(prob);
00332         int n_a  = prob[index];
00333         int n_b  = prob[(index+1)%2];
00334         int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
00335         int n_tilde = SB::tree_count_ - (n_a +n_b);
00336         if(n_tilde <= 0) n_tilde = 0;
00337         if(n_needed <= 0) n_needed = 0;
00338         double p = 0;
00339         for(int ii = n_needed; ii < n_tilde; ++ii)
00340             p += binomial(n_tilde, ii, 0.5);
00341         
00342         if(p >= 1-alpha_)
00343         {
00344             depths.push_back(double(k+1)/double(SB::tree_count_));
00345             return true;
00346         }
00347 
00348         return false;
00349     }
00350 };
00351 } //namespace vigra;
00352 #endif //RF_EARLY_STOPPING_P_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)