[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX 00036 #define VIGRA_RANDOM_FOREST_SPLIT_HXX 00037 #include <algorithm> 00038 #include <map> 00039 #include <numeric> 00040 #include "../mathutil.hxx" 00041 #include "../array_vector.hxx" 00042 #include "../sized_int.hxx" 00043 #include "../matrix.hxx" 00044 #include "../random.hxx" 00045 #include "../functorexpression.hxx" 00046 #include "rf_nodeproxy.hxx" 00047 #include "rf_sampling.hxx" 00048 //#include "../hokashyap.hxx" 00049 //#include "vigra/rf_helpers.hxx" 00050 00051 namespace vigra 00052 { 00053 00054 // Incomplete Class to ensure that findBestSplit is always implemented in 00055 // the derived classes of SplitBase 00056 class CompileTimeError; 00057 00058 00059 namespace detail 00060 { 00061 template<class Tag> 00062 class Normalise 00063 { 00064 public: 00065 template<class Iter> 00066 static void exec(Iter begin, Iter end) 00067 {} 00068 }; 00069 00070 template<> 00071 class Normalise<ClassificationTag> 00072 { 00073 public: 00074 template<class Iter> 00075 static void exec (Iter begin, Iter end) 00076 { 00077 double bla = std::accumulate(begin, end, 0.0); 00078 for(int ii = 0; ii < end - begin; ++ii) 00079 begin[ii] = begin[ii]/bla ; 00080 } 00081 }; 00082 } 00083 00084 00085 /** Base Class for all SplitFunctors used with the \ref RandomForest class 00086 defines the interface used while learning a tree. 00087 **/ 00088 template<class Tag> 00089 class SplitBase 00090 { 00091 public: 00092 00093 typedef Tag RF_Tag; 00094 typedef DT_StackEntry<ArrayVectorView<Int32>::iterator> 00095 StackEntry_t; 00096 00097 ProblemSpec<> ext_param_; 00098 00099 NodeBase::T_Container_type t_data; 00100 NodeBase::P_Container_type p_data; 00101 00102 NodeBase node_; 00103 00104 /** returns the DecisionTree Node created by 00105 \ref findBestSplit or \ref makeTerminalNode. 00106 **/ 00107 00108 template<class T> 00109 void set_external_parameters(ProblemSpec<T> const & in) 00110 { 00111 ext_param_ = in; 00112 t_data.push_back(in.column_count_); 00113 t_data.push_back(in.class_count_); 00114 } 00115 00116 NodeBase & createNode() 00117 { 00118 return node_; 00119 } 00120 00121 int classCount() const 00122 { 00123 return int(t_data[1]); 00124 } 00125 00126 int featureCount() const 00127 { 00128 return int(t_data[0]); 00129 } 00130 00131 /** resets internal data. Should always be called before 00132 calling findBestSplit or makeTerminalNode 00133 **/ 00134 void reset() 00135 { 00136 t_data.resize(2); 00137 p_data.resize(0); 00138 } 00139 00140 00141 /** findBestSplit has to be implemented in derived split functor. 00142 these functions only insures That a CompileTime error is issued 00143 if no such method was defined. 00144 **/ 00145 00146 template<class T, class C, class T2, class C2, class Region, class Random> 00147 int findBestSplit(MultiArrayView<2, T, C> features, 00148 MultiArrayView<2, T2, C2> labels, 00149 Region region, 00150 ArrayVector<Region> childs, 00151 Random randint) 00152 { 00153 CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined; 00154 return 0; 00155 } 00156 00157 /** default action for creating a terminal Node. 00158 sets the Class probability of the remaining region according to 00159 the class histogram 00160 **/ 00161 template<class T, class C, class T2,class C2, class Region, class Random> 00162 int makeTerminalNode(MultiArrayView<2, T, C> features, 00163 MultiArrayView<2, T2, C2> labels, 00164 Region & region, 00165 Random randint) 00166 { 00167 Node<e_ConstProbNode> ret(t_data, p_data); 00168 node_ = ret; 00169 if(ext_param_.class_weights_.size() != region.classCounts().size()) 00170 { 00171 std::copy( region.classCounts().begin(), 00172 region.classCounts().end(), 00173 ret.prob_begin()); 00174 } 00175 else 00176 { 00177 std::transform( region.classCounts().begin(), 00178 region.classCounts().end(), 00179 ext_param_.class_weights_.begin(), 00180 ret.prob_begin(), std::multiplies<double>()); 00181 } 00182 detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end()); 00183 ret.weights() = region.size(); 00184 return e_ConstProbNode; 00185 } 00186 00187 00188 }; 00189 00190 /** Functor to sort the indices of a feature Matrix by a certain dimension 00191 **/ 00192 template<class DataMatrix> 00193 class SortSamplesByDimensions 00194 { 00195 DataMatrix const & data_; 00196 MultiArrayIndex sortColumn_; 00197 double thresVal_; 00198 public: 00199 00200 SortSamplesByDimensions(DataMatrix const & data, 00201 MultiArrayIndex sortColumn, 00202 double thresVal = 0.0) 00203 : data_(data), 00204 sortColumn_(sortColumn), 00205 thresVal_(thresVal) 00206 {} 00207 00208 void setColumn(MultiArrayIndex sortColumn) 00209 { 00210 sortColumn_ = sortColumn; 00211 } 00212 void setThreshold(double value) 00213 { 00214 thresVal_ = value; 00215 } 00216 00217 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const 00218 { 00219 return data_(l, sortColumn_) < data_(r, sortColumn_); 00220 } 00221 bool operator()(MultiArrayIndex l) const 00222 { 00223 return data_(l, sortColumn_) < thresVal_; 00224 } 00225 }; 00226 00227 template<class DataMatrix> 00228 class DimensionNotEqual 00229 { 00230 DataMatrix const & data_; 00231 MultiArrayIndex sortColumn_; 00232 00233 public: 00234 00235 DimensionNotEqual(DataMatrix const & data, 00236 MultiArrayIndex sortColumn) 00237 : data_(data), 00238 sortColumn_(sortColumn) 00239 {} 00240 00241 void setColumn(MultiArrayIndex sortColumn) 00242 { 00243 sortColumn_ = sortColumn; 00244 } 00245 00246 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const 00247 { 00248 return data_(l, sortColumn_) != data_(r, sortColumn_); 00249 } 00250 }; 00251 00252 template<class DataMatrix> 00253 class SortSamplesByHyperplane 00254 { 00255 DataMatrix const & data_; 00256 Node<i_HyperplaneNode> const & node_; 00257 00258 public: 00259 00260 SortSamplesByHyperplane(DataMatrix const & data, 00261 Node<i_HyperplaneNode> const & node) 00262 : 00263 data_(data), 00264 node_() 00265 {} 00266 00267 /** calculate the distance of a sample point to a hyperplane 00268 */ 00269 double operator[](MultiArrayIndex l) const 00270 { 00271 double result_l = -1 * node_.intercept(); 00272 for(int ii = 0; ii < node_.columns_size(); ++ii) 00273 { 00274 result_l += rowVector(data_, l)[node_.columns_begin()[ii]] 00275 * node_.weights()[ii]; 00276 } 00277 return result_l; 00278 } 00279 00280 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const 00281 { 00282 return (*this)[l] < (*this)[r]; 00283 } 00284 00285 }; 00286 00287 /** makes a Class Histogram given indices in a labels_ array 00288 * usage: 00289 * MultiArrayView<2, T2, C2> labels = makeSomeLabels() 00290 * ArrayVector<int> hist(numberOfLabels(labels), 0); 00291 * RandomForestClassCounter<T2, C2, ArrayVector> counter(labels, hist); 00292 * 00293 * Container<int> indices = getSomeIndices() 00294 * std::for_each(indices, counter); 00295 */ 00296 template <class DataSource, class CountArray> 00297 class RandomForestClassCounter 00298 { 00299 DataSource const & labels_; 00300 CountArray & counts_; 00301 00302 public: 00303 00304 RandomForestClassCounter(DataSource const & labels, 00305 CountArray & counts) 00306 : labels_(labels), 00307 counts_(counts) 00308 { 00309 reset(); 00310 } 00311 00312 void reset() 00313 { 00314 counts_.init(0); 00315 } 00316 00317 void operator()(MultiArrayIndex l) const 00318 { 00319 counts_[labels_[l]] +=1; 00320 } 00321 }; 00322 00323 00324 /** Functor To Calculate the Best possible Split Based on the Gini Index 00325 given Labels and Features along a given Axis 00326 */ 00327 00328 namespace detail 00329 { 00330 template<int N> 00331 class ConstArr 00332 { 00333 public: 00334 double operator[](size_t) const 00335 { 00336 return (double)N; 00337 } 00338 }; 00339 00340 00341 } 00342 00343 00344 00345 00346 00347 /** Functor to calculate the gini impurity 00348 */ 00349 class GiniCriterion 00350 { 00351 public: 00352 /**caculate the weighted gini impurity based on class histogram 00353 * and class weights 00354 */ 00355 template<class Array, class Array2> 00356 double operator() (Array const & hist, 00357 Array2 const & weights, 00358 double total = 1.0) const 00359 { 00360 return impurity(hist, weights, total); 00361 } 00362 00363 /** calculate the gini based impurity based on class histogram 00364 */ 00365 template<class Array> 00366 double operator()(Array const & hist, double total = 1.0) const 00367 { 00368 return impurity(hist, total); 00369 } 00370 00371 /** static version of operator(hist total) 00372 */ 00373 template<class Array> 00374 static double impurity(Array const & hist, double total) 00375 { 00376 return impurity(hist, detail::ConstArr<1>(), total); 00377 } 00378 00379 /** static version of operator(hist, weights, total) 00380 */ 00381 template<class Array, class Array2> 00382 static double impurity (Array const & hist, 00383 Array2 const & weights, 00384 double total) 00385 { 00386 00387 int class_count = hist.size(); 00388 double gini = 0; 00389 if(class_count == 2) 00390 { 00391 double w = weights[0] * weights[1]; 00392 gini = w * (hist[0] * hist[1] / total); 00393 } 00394 else 00395 { 00396 for(int ii = 0; ii < class_count; ++ii) 00397 { 00398 double w = weights[ii]; 00399 gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) ); 00400 } 00401 } 00402 return gini; 00403 } 00404 }; 00405 00406 00407 template <class DataSource, class Impurity= GiniCriterion> 00408 class ImpurityLoss 00409 { 00410 00411 DataSource const & labels_; 00412 ArrayVector<double> counts_; 00413 ArrayVector<double> const & class_weights_; 00414 double total_counts_; 00415 Impurity impurity_; 00416 00417 public: 00418 00419 template<class T> 00420 ImpurityLoss(DataSource const & labels, 00421 ProblemSpec<T> const & ext_) 00422 : labels_(labels), 00423 counts_(ext_.class_count_, 0.0), 00424 class_weights_(ext_.class_weights_), 00425 total_counts_(0.0) 00426 {} 00427 00428 void reset() 00429 { 00430 counts_.init(0); 00431 total_counts_ = 0.0; 00432 } 00433 00434 template<class Counts> 00435 double increment_histogram(Counts const & counts) 00436 { 00437 std::transform(counts.begin(), counts.end(), 00438 counts_.begin(), counts_.begin(), 00439 std::plus<double>()); 00440 total_counts_ = std::accumulate( counts_.begin(), 00441 counts_.end(), 00442 0.0); 00443 return impurity_(counts_, class_weights_, total_counts_); 00444 } 00445 00446 template<class Counts> 00447 double decrement_histogram(Counts const & counts) 00448 { 00449 std::transform(counts.begin(), counts.end(), 00450 counts_.begin(), counts_.begin(), 00451 std::minus<double>()); 00452 total_counts_ = std::accumulate( counts_.begin(), 00453 counts_.end(), 00454 0.0); 00455 return impurity_(counts_, class_weights_, total_counts_); 00456 } 00457 00458 template<class Iter> 00459 double increment(Iter begin, Iter end) 00460 { 00461 for(Iter iter = begin; iter != end; ++iter) 00462 { 00463 counts_[labels_[*iter]] +=1; 00464 total_counts_ +=1; 00465 } 00466 return impurity_(counts_, class_weights_, total_counts_); 00467 } 00468 00469 template<class Iter> 00470 double decrement(Iter begin, Iter end) 00471 { 00472 for(Iter iter = begin; iter != end; ++iter) 00473 { 00474 counts_[labels_[*iter]] -=1; 00475 total_counts_ -=1; 00476 } 00477 return impurity_(counts_, class_weights_, total_counts_); 00478 } 00479 00480 template<class Iter, class Resp_t> 00481 double init (Iter begin, Iter end, Resp_t resp) 00482 { 00483 reset(); 00484 std::copy(resp.begin(), resp.end(), counts_.begin()); 00485 total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 00486 return impurity_(counts_,class_weights_, total_counts_); 00487 } 00488 00489 ArrayVector<double> const & response() 00490 { 00491 return counts_; 00492 } 00493 }; 00494 00495 template <class DataSource> 00496 class RegressionForestCounter 00497 { 00498 typedef MultiArrayShape<2>::type Shp; 00499 DataSource const & labels_; 00500 ArrayVector <double> mean_; 00501 ArrayVector <double> variance_; 00502 ArrayVector <double> tmp_; 00503 size_t count_; 00504 00505 template<class T> 00506 RegressionForestCounter(DataSource const & labels, 00507 ProblemSpec<T> const & ext_) 00508 : 00509 labels_(labels), 00510 mean_(ext_.response_size, 0.0), 00511 variance_(ext_.response_size, 0.0), 00512 tmp_(ext_.response_size), 00513 count_(0) 00514 {} 00515 00516 // west's alorithm for incremental variance 00517 // calculation 00518 template<class Iter> 00519 double increment (Iter begin, Iter end) 00520 { 00521 for(Iter iter = begin; iter != end; ++iter) 00522 { 00523 ++count_; 00524 for(int ii = 0; ii < mean_.size(); ++ii) 00525 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 00526 double f = 1.0 / count_, 00527 f1 = 1.0 - f; 00528 for(int ii = 0; ii < mean_.size(); ++ii) 00529 mean_[ii] += f*tmp_[ii]; 00530 for(int ii = 0; ii < mean_.size(); ++ii) 00531 variance_[ii] += f1*sq(tmp_[ii]); 00532 } 00533 return std::accumulate(variance_.begin(), 00534 variance_.end(), 00535 0.0, 00536 std::plus<double>()) 00537 /(count_ -1); 00538 } 00539 00540 template<class Iter> 00541 double decrement (Iter begin, Iter end) 00542 { 00543 for(Iter iter = begin; iter != end; ++iter) 00544 { 00545 --count_; 00546 for(int ii = 0; ii < mean_.size(); ++ii) 00547 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 00548 double f = 1.0 / count_, 00549 f1 = 1.0 + f; 00550 for(int ii = 0; ii < mean_.size(); ++ii) 00551 mean_[ii] -= f*tmp_[ii]; 00552 for(int ii = 0; ii < mean_.size(); ++ii) 00553 variance_[ii] -= f1*sq(tmp_[ii]); 00554 } 00555 return std::accumulate(variance_.begin(), 00556 variance_.end(), 00557 0.0, 00558 std::plus<double>()) 00559 /(count_ -1); 00560 } 00561 00562 template<class Iter, class Resp_t> 00563 double init (Iter begin, Iter end, Resp_t resp) 00564 { 00565 reset(); 00566 return increment(begin, end); 00567 } 00568 00569 00570 ArrayVector<double> const & response() 00571 { 00572 return mean_; 00573 } 00574 00575 void reset() 00576 { 00577 mean_.init(0.0); 00578 variance_.init(0.0); 00579 count_ = 0; 00580 } 00581 }; 00582 00583 template<class Tag, class Datatyp> 00584 struct LossTraits; 00585 00586 struct LSQLoss 00587 {}; 00588 00589 template<class Datatype> 00590 struct LossTraits<GiniCriterion, Datatype> 00591 { 00592 typedef ImpurityLoss<Datatype, GiniCriterion> type; 00593 }; 00594 00595 template<class Datatype> 00596 struct LossTraits<LSQLoss, Datatype> 00597 { 00598 typedef RegressionForestCounter<Datatype> type; 00599 }; 00600 00601 template<class LineSearchLossTag> 00602 class BestGiniOfColumn 00603 { 00604 public: 00605 ArrayVector<double> class_weights_; 00606 ArrayVector<double> bestCurrentCounts[2]; 00607 double min_gini_; 00608 ptrdiff_t min_index_; 00609 double min_threshold_; 00610 ProblemSpec<> ext_param_; 00611 00612 BestGiniOfColumn() 00613 {} 00614 00615 template<class T> 00616 BestGiniOfColumn(ProblemSpec<T> const & ext) 00617 : 00618 class_weights_(ext.class_weights_), 00619 ext_param_(ext) 00620 { 00621 bestCurrentCounts[0].resize(ext.class_count_); 00622 bestCurrentCounts[1].resize(ext.class_count_); 00623 } 00624 template<class T> 00625 void set_external_parameters(ProblemSpec<T> const & ext) 00626 { 00627 class_weights_ = ext.class_weights_; 00628 ext_param_ = ext; 00629 bestCurrentCounts[0].resize(ext.class_count_); 00630 bestCurrentCounts[1].resize(ext.class_count_); 00631 } 00632 /** calculate the best gini split along a Feature Column 00633 * \param column, the feature vector - has to support the [] operator 00634 * \param labels, the label vector 00635 * \param begin 00636 * \param end (in and out) 00637 * begin and end iterators to the indices of the 00638 * samples in the current region. 00639 * the range begin - end is sorted by the column supplied 00640 * during function execution. 00641 * \param class_counts 00642 * class histogram of the range. 00643 * 00644 * precondition: begin, end valid range, 00645 * class_counts positive integer valued array with the 00646 * class counts in the current range. 00647 * labels.size() >= max(begin, end); 00648 * postcondition: 00649 * begin, end sorted by column given. 00650 * min_gini_ contains the minimum gini found or 00651 * NumericTraits<double>::max if no split was found. 00652 * min_index_ countains the splitting index in the range 00653 * or invalid data if no split was found. 00654 * BestCirremtcounts[0] and [1] contain the 00655 * class histogram of the left and right region of 00656 * the left and right regions. 00657 */ 00658 template< class DataSourceF_t, 00659 class DataSource_t, 00660 class I_Iter, 00661 class Array> 00662 void operator()(DataSourceF_t const & column, 00663 DataSource_t const & labels, 00664 I_Iter & begin, 00665 I_Iter & end, 00666 Array const & region_response) 00667 { 00668 std::sort(begin, end, 00669 SortSamplesByDimensions<DataSourceF_t>(column, 0)); 00670 typedef typename 00671 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 00672 LineSearchLoss left(labels, ext_param_); 00673 LineSearchLoss right(labels, ext_param_); 00674 00675 00676 00677 min_gini_ = right.init(begin, end, region_response); 00678 min_threshold_ = *begin; 00679 min_index_ = 0; 00680 DimensionNotEqual<DataSourceF_t> comp(column, 0); 00681 00682 I_Iter iter = begin; 00683 I_Iter next = std::adjacent_find(iter, end, comp); 00684 while( next != end) 00685 { 00686 00687 double loss = right.decrement(iter, next + 1) 00688 + left.increment(iter , next + 1); 00689 if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_)) 00690 { 00691 bestCurrentCounts[0] = left.response(); 00692 bestCurrentCounts[1] = right.response(); 00693 min_gini_ = loss < min_gini_? loss : min_gini_; 00694 min_index_ = next - begin +1 ; 00695 min_threshold_ = (column[*next] + column[*(next +1)])/2; 00696 } 00697 iter = next +1 ; 00698 next = std::adjacent_find(iter, end, comp); 00699 } 00700 00701 } 00702 00703 template<class DataSource_t, class Iter, class Array> 00704 double loss_of_region(DataSource_t const & labels, 00705 Iter & begin, 00706 Iter & end, 00707 Array const & region_response) const 00708 { 00709 typedef typename 00710 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 00711 LineSearchLoss region_loss(labels, ext_param_); 00712 return 00713 region_loss.init(begin, end, region_response); 00714 } 00715 00716 }; 00717 00718 template<class ColumnDecisionFunctor, class Tag> 00719 class ThresholdSplit: public SplitBase<Tag> 00720 { 00721 public: 00722 00723 00724 typedef SplitBase<Tag> SB; 00725 00726 ArrayVector<Int32> splitColumns; 00727 ColumnDecisionFunctor bgfunc; 00728 00729 double region_gini_; 00730 ArrayVector<double> min_gini_; 00731 ArrayVector<ptrdiff_t> min_indices_; 00732 ArrayVector<double> min_thresholds_; 00733 00734 int bestSplitIndex; 00735 00736 double minGini() const 00737 { 00738 return min_gini_[bestSplitIndex]; 00739 } 00740 int bestSplitColumn() const 00741 { 00742 return splitColumns[bestSplitIndex]; 00743 } 00744 double bestSplitThreshold() const 00745 { 00746 return min_thresholds_[bestSplitIndex]; 00747 } 00748 00749 template<class T> 00750 void set_external_parameters(ProblemSpec<T> const & in) 00751 { 00752 SB::set_external_parameters(in); 00753 bgfunc.set_external_parameters( SB::ext_param_); 00754 int featureCount_ = SB::ext_param_.column_count_; 00755 splitColumns.resize(featureCount_); 00756 for(int k=0; k<featureCount_; ++k) 00757 splitColumns[k] = k; 00758 min_gini_.resize(featureCount_); 00759 min_indices_.resize(featureCount_); 00760 min_thresholds_.resize(featureCount_); 00761 } 00762 00763 00764 template<class T, class C, class T2, class C2, class Region, class Random> 00765 int findBestSplit(MultiArrayView<2, T, C> features, 00766 MultiArrayView<2, T2, C2> labels, 00767 Region & region, 00768 ArrayVector<Region>& childRegions, 00769 Random & randint) 00770 { 00771 00772 typedef typename Region::IndexIterator IndexIterator; 00773 if(region.size() == 0) 00774 { 00775 std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n" 00776 "continuing learning process...."; 00777 } 00778 // calculate things that haven't been calculated yet. 00779 00780 if(std::accumulate(region.classCounts().begin(), 00781 region.classCounts().end(), 0) != region.size()) 00782 { 00783 RandomForestClassCounter< MultiArrayView<2,T2, C2>, 00784 ArrayVector<Int32> > 00785 counter(labels, region.classCounts()); 00786 std::for_each( region.begin(), region.end(), counter); 00787 region.classCountsIsValid = true; 00788 } 00789 00790 // Is the region pure already? 00791 region_gini_ = bgfunc.loss_of_region(labels, 00792 region.begin(), 00793 region.end(), 00794 region.classCounts()); 00795 if(region_gini_ <= SB::ext_param_.precision_) 00796 return makeTerminalNode(features, labels, region, randint); 00797 00798 // select columns to be tried. 00799 for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii) 00800 std::swap(splitColumns[ii], 00801 splitColumns[ii+ randint(features.shape(1) - ii)]); 00802 00803 // find the best gini index 00804 bestSplitIndex = 0; 00805 double current_min_gini = region_gini_; 00806 int num2try = features.shape(1); 00807 for(int k=0; k<num2try; ++k) 00808 { 00809 //this functor does all the work 00810 bgfunc(columnVector(features, splitColumns[k]), 00811 labels, 00812 region.begin(), region.end(), 00813 region.classCounts()); 00814 min_gini_[k] = bgfunc.min_gini_; 00815 min_indices_[k] = bgfunc.min_index_; 00816 min_thresholds_[k] = bgfunc.min_threshold_; 00817 00818 if( bgfunc.min_gini_ < current_min_gini 00819 && !closeAtTolerance(bgfunc.min_gini_, current_min_gini)) 00820 { 00821 current_min_gini = bgfunc.min_gini_; 00822 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0]; 00823 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1]; 00824 childRegions[0].classCountsIsValid = true; 00825 childRegions[1].classCountsIsValid = true; 00826 00827 bestSplitIndex = k; 00828 num2try = SB::ext_param_.actual_mtry_; 00829 } 00830 } 00831 00832 // did not find any suitable split 00833 if(closeAtTolerance(current_min_gini, region_gini_)) 00834 return makeTerminalNode(features, labels, region, randint); 00835 00836 //create a Node for output 00837 Node<i_ThresholdNode> node(SB::t_data, SB::p_data); 00838 SB::node_ = node; 00839 node.threshold() = min_thresholds_[bestSplitIndex]; 00840 node.column() = splitColumns[bestSplitIndex]; 00841 00842 // partition the range according to the best dimension 00843 SortSamplesByDimensions<MultiArrayView<2, T, C> > 00844 sorter(features, node.column(), node.threshold()); 00845 IndexIterator bestSplit = 00846 std::partition(region.begin(), region.end(), sorter); 00847 // Save the ranges of the child stack entries. 00848 childRegions[0].setRange( region.begin() , bestSplit ); 00849 childRegions[0].rule = region.rule; 00850 childRegions[0].rule.push_back(std::make_pair(1, 1.0)); 00851 childRegions[1].setRange( bestSplit , region.end() ); 00852 childRegions[1].rule = region.rule; 00853 childRegions[1].rule.push_back(std::make_pair(1, 1.0)); 00854 00855 return i_ThresholdNode; 00856 } 00857 }; 00858 00859 typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit; 00860 typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit; 00861 00862 } //namespace vigra 00863 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|