[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 #ifndef RF_EARLY_STOPPING_P_HXX 00002 #define RF_EARLY_STOPPING_P_HXX 00003 #include <cmath> 00004 #include "rf_common.hxx" 00005 00006 namespace vigra 00007 { 00008 00009 #if 0 00010 namespace es_detail 00011 { 00012 template<class T> 00013 T power(T const & in, int n) 00014 { 00015 T result = NumericTraits<T>::one(); 00016 for(int ii = 0; ii < n ;++ii) 00017 result *= in; 00018 return result; 00019 } 00020 } 00021 #endif 00022 class StopBase 00023 { 00024 protected: 00025 ProblemSpec<> ext_param_; 00026 int tree_count_ ; 00027 bool is_weighted_; 00028 00029 public: 00030 template<class T> 00031 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00032 { 00033 ext_param_ = prob; 00034 is_weighted_ = is_weighted; 00035 tree_count_ = tree_count; 00036 } 00037 }; 00038 00039 00040 /**Stop predicting after a set number of trees 00041 */ 00042 class StopAfterTree : public StopBase 00043 { 00044 public: 00045 double max_tree_p; 00046 int max_tree_; 00047 typedef StopBase SB; 00048 00049 ArrayVector<double> depths; 00050 StopAfterTree(double max_tree) 00051 : 00052 max_tree_p(max_tree) 00053 {} 00054 00055 template<class T> 00056 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00057 { 00058 max_tree_ = ceil(max_tree_p * tree_count); 00059 SB::set_external_parameters(prob, tree_count, is_weighted); 00060 } 00061 00062 template<class WeightIter, class T, class C> 00063 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */) 00064 { 00065 if(k == SB::tree_count_ -1) 00066 { 00067 depths.push_back(double(k+1)/double(SB::tree_count_)); 00068 return false; 00069 } 00070 if(k < max_tree_) 00071 return false; 00072 depths.push_back(double(k+1)/double(SB::tree_count_)); 00073 return true; 00074 } 00075 }; 00076 00077 /** Stop predicting after a certain amount of votes exceed certain proportion. 00078 * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_ 00079 * case weighted votion: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ; 00080 * (maximal number of votes possible in both cases) 00081 */ 00082 class StopAfterVoteCount : public StopBase 00083 { 00084 public: 00085 double proportion_; 00086 typedef StopBase SB; 00087 ArrayVector<double> depths; 00088 StopAfterVoteCount(double proportion) 00089 : 00090 proportion_(proportion) 00091 {} 00092 00093 template<class WeightIter, class T, class C> 00094 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */) 00095 { 00096 if(k == SB::tree_count_ -1) 00097 { 00098 depths.push_back(double(k+1)/double(SB::tree_count_)); 00099 return false; 00100 } 00101 00102 00103 if(SB::is_weighted_) 00104 { 00105 if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_) 00106 { 00107 depths.push_back(double(k+1)/double(SB::tree_count_)); 00108 return true; 00109 } 00110 } 00111 else 00112 { 00113 if(prob[argMax(prob)] > proportion_ * SB::tree_count_) 00114 { 00115 depths.push_back(double(k+1)/double(SB::tree_count_)); 00116 return true; 00117 } 00118 } 00119 return false; 00120 } 00121 00122 }; 00123 00124 00125 /** Stop predicting if the 2norm of the probabilities does not change*/ 00126 class StopIfConverging : public StopBase 00127 00128 { 00129 public: 00130 double thresh_; 00131 int num_; 00132 MultiArray<2, double> last_; 00133 MultiArray<2, double> cur_; 00134 ArrayVector<double> depths; 00135 typedef StopBase SB; 00136 00137 StopIfConverging(double thresh, int num = 10) 00138 : 00139 thresh_(thresh), 00140 num_(num) 00141 {} 00142 00143 template<class T> 00144 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false) 00145 { 00146 last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0); 00147 cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0); 00148 SB::set_external_parameters(prob, tree_count, is_weighted); 00149 } 00150 template<class WeightIter, class T, class C> 00151 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> const & prob, double totalCt) 00152 { 00153 if(k == SB::tree_count_ -1) 00154 { 00155 depths.push_back(double(k+1)/double(SB::tree_count_)); 00156 return false; 00157 } 00158 if(k <= num_) 00159 { 00160 last_ = prob; 00161 last_/= last_.norm(1); 00162 return false; 00163 } 00164 else 00165 { 00166 cur_ = prob; 00167 cur_ /= cur_.norm(1); 00168 last_ -= cur_; 00169 double nrm = last_.norm(); 00170 if(nrm < thresh_) 00171 { 00172 depths.push_back(double(k+1)/double(SB::tree_count_)); 00173 return true; 00174 } 00175 else 00176 { 00177 last_ = cur_; 00178 } 00179 } 00180 return false; 00181 } 00182 }; 00183 00184 00185 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion 00186 * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_ 00187 * case weighted votion: stop if margin exceeds proportion * msample_ * SB::tree_count_ ; 00188 * (maximal number of votes possible in both cases) 00189 */ 00190 class StopIfMargin : public StopBase 00191 { 00192 public: 00193 double proportion_; 00194 typedef StopBase SB; 00195 ArrayVector<double> depths; 00196 00197 StopIfMargin(double proportion) 00198 : 00199 proportion_(proportion) 00200 {} 00201 00202 template<class WeightIter, class T, class C> 00203 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */) 00204 { 00205 if(k == SB::tree_count_ -1) 00206 { 00207 depths.push_back(double(k+1)/double(SB::tree_count_)); 00208 return false; 00209 } 00210 int index = argMax(prob); 00211 double a = prob[argMax(prob)]; 00212 prob[argMax(prob)] = 0; 00213 double b = prob[argMax(prob)]; 00214 prob[index] = a; 00215 double margin = a - b; 00216 if(SB::is_weighted_) 00217 { 00218 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_) 00219 { 00220 depths.push_back(double(k+1)/double(SB::tree_count_)); 00221 return true; 00222 } 00223 } 00224 else 00225 { 00226 if(prob[argMax(prob)] > proportion_ * SB::tree_count_) 00227 { 00228 depths.push_back(double(k+1)/double(SB::tree_count_)); 00229 return true; 00230 } 00231 } 00232 return false; 00233 } 00234 }; 00235 00236 class StopIfBinTest : public StopBase 00237 { 00238 public: 00239 double alpha_; 00240 MultiArrayView<2, double> n_choose_k; 00241 StopIfBinTest(double alpha, MultiArrayView<2, double> nck_) 00242 : 00243 alpha_(alpha), 00244 n_choose_k(nck_) 00245 {} 00246 typedef StopBase SB; 00247 ArrayVector<double> depths; 00248 00249 double binomial(int N, int k, double p) 00250 { 00251 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k); 00252 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k); 00253 } 00254 00255 template<class WeightIter, class T, class C> 00256 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt) 00257 { 00258 if(k == SB::tree_count_ -1) 00259 { 00260 depths.push_back(double(k+1)/double(SB::tree_count_)); 00261 return false; 00262 } 00263 if(k < 10) 00264 { 00265 return false; 00266 } 00267 int index = argMax(prob); 00268 int n_a = prob[index]; 00269 int n_b = prob[(index+1)%2]; 00270 int n_tilde = (SB::tree_count_ - n_a + n_b); 00271 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde); 00272 vigra_precondition(p_a <= 1, "probability should be smaller than 1"); 00273 double cum_val = 0; 00274 int c = 0; 00275 // std::cerr << "prob: " << p_a << std::endl; 00276 if(n_a <= 0)n_a = 0; 00277 if(n_b <= 0)n_b = 0; 00278 for(int ii = 0; ii <= n_b + n_a;++ii) 00279 { 00280 // std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl; 00281 cum_val += binomial(n_b + n_a, ii, p_a); 00282 if(cum_val >= 1 -alpha_) 00283 { 00284 c = ii; 00285 break; 00286 } 00287 } 00288 // std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl; 00289 if(c < n_a) 00290 { 00291 depths.push_back(double(k+1)/double(SB::tree_count_)); 00292 return true; 00293 } 00294 00295 return false; 00296 } 00297 }; 00298 00299 00300 class StopIfProb : public StopBase 00301 { 00302 public: 00303 double alpha_; 00304 MultiArrayView<2, double> n_choose_k; 00305 StopIfProb(double alpha, MultiArrayView<2, double> nck_) 00306 : 00307 alpha_(alpha), 00308 n_choose_k(nck_) 00309 {} 00310 typedef StopBase SB; 00311 ArrayVector<double> depths; 00312 00313 double binomial(int N, int k, double p) 00314 { 00315 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k); 00316 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k); 00317 } 00318 00319 template<class WeightIter, class T, class C> 00320 bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt) 00321 { 00322 if(k == SB::tree_count_ -1) 00323 { 00324 depths.push_back(double(k+1)/double(SB::tree_count_)); 00325 return false; 00326 } 00327 if(k <= 10) 00328 { 00329 return false; 00330 } 00331 int index = argMax(prob); 00332 int n_a = prob[index]; 00333 int n_b = prob[(index+1)%2]; 00334 int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a; 00335 int n_tilde = SB::tree_count_ - (n_a +n_b); 00336 if(n_tilde <= 0) n_tilde = 0; 00337 if(n_needed <= 0) n_needed = 0; 00338 double p = 0; 00339 for(int ii = n_needed; ii < n_tilde; ++ii) 00340 p += binomial(n_tilde, ii, 0.5); 00341 00342 if(p >= 1-alpha_) 00343 { 00344 depths.push_back(double(k+1)/double(SB::tree_count_)); 00345 return true; 00346 } 00347 00348 return false; 00349 } 00350 }; 00351 } //namespace vigra; 00352 #endif //RF_EARLY_STOPPING_P_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|