/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,   */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/
#ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
#define VIGRA_RANDOM_FOREST_SPLIT_HXX
#include <algorithm>
#include <cstddef>
#include <map>
#include <numeric>
#include <math.h>
#include "../mathutil.hxx"
#include "../array_vector.hxx"
#include "../sized_int.hxx"
#include "../matrix.hxx"
#include "../random.hxx"
#include "../functorexpression.hxx"
#include "rf_nodeproxy.hxx"
//#include "rf_sampling.hxx"
#include "rf_region.hxx"
//#include "../hokashyap.hxx"
//#include "vigra/rf_helpers.hxx"

namespace vigra
{

// Incomplete class to ensure that findBestSplit() is always implemented in
// the derived classes of SplitBase.
class CompileTimeError;


namespace detail
{
    template<class Tag>
    class Normalise
    {
      public:
        template<class Iter>
        static void exec(Iter begin, Iter end)
        {}
    };

    template<>
    class Normalise<ClassificationTag>
    {
      public:
        template<class Iter>
        static void exec(Iter begin, Iter end)
        {
            // normalise the entries so that they sum to one
            double sum = std::accumulate(begin, end, 0.0);
            for(int ii = 0; ii < end - begin; ++ii)
                begin[ii] = begin[ii] / sum;
        }
    };
}


/** Base class for all split functors used with the \ref RandomForest class.
    Defines the interface used while learning a tree.
**/
template<class Tag>
class SplitBase
{
  public:

    typedef Tag           RF_Tag;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                          StackEntry_t;

    ProblemSpec<>         ext_param_;

    NodeBase::T_Container_type t_data;
    NodeBase::P_Container_type p_data;

    NodeBase node_;

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & in)
    {
        ext_param_ = in;
        t_data.push_back(in.column_count_);
        t_data.push_back(in.class_count_);
    }

    /** returns the decision tree node created by
        \ref findBestSplit or \ref makeTerminalNode.
    **/
    NodeBase & createNode()
    {
        return node_;
    }

    int classCount() const
    {
        return int(t_data[1]);
    }

    int featureCount() const
    {
        return int(t_data[0]);
    }

    /** resets internal data. Should always be called before
        calling findBestSplit or makeTerminalNode.
    **/
    void reset()
    {
        t_data.resize(2);
        p_data.resize(0);
    }


    /** findBestSplit has to be implemented in the derived split functor.
        This stub only ensures that a compile-time error is issued
        if no such method is defined.
    **/
    template<class T, class C, class T2, class C2, class Region, class Random>
    int findBestSplit(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2> labels,
                      Region region,
                      ArrayVector<Region> childs,
                      Random randint)
    {
        CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
        return 0;
    }

    /** default action for creating a terminal node:
        sets the class probabilities of the remaining region according to
        the class histogram.
    **/
    template<class T, class C, class T2, class C2, class Region, class Random>
    int makeTerminalNode(MultiArrayView<2, T, C> features,
                         MultiArrayView<2, T2, C2> labels,
                         Region & region,
                         Random randint)
    {
        Node<e_ConstProbNode> ret(t_data, p_data);
        node_ = ret;
        if(ext_param_.class_weights_.size() != region.classCounts().size())
        {
            std::copy(region.classCounts().begin(),
                      region.classCounts().end(),
                      ret.prob_begin());
        }
        else
        {
            std::transform(region.classCounts().begin(),
                           region.classCounts().end(),
                           ext_param_.class_weights_.begin(),
                           ret.prob_begin(), std::multiplies<double>());
        }
        detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
        ret.weights() = region.size();
        return e_ConstProbNode;
    }

};
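/* Illustration (added to this listing; not part of the original header):
   a minimal split functor. Derived classes provide findBestSplit(); the
   smallest possible implementation simply turns every region into a leaf
   by delegating to SplitBase::makeTerminalNode(). The class name is made
   up for the example, and it assumes region.classCounts() has already been
   filled (ThresholdSplit below does this before using the counts).

       class AlwaysTerminalSplit
       : public SplitBase<ClassificationTag>
       {
         public:
           typedef SplitBase<ClassificationTag> SB;

           template<class T, class C, class T2, class C2, class Region, class Random>
           int findBestSplit(MultiArrayView<2, T, C>   features,
                             MultiArrayView<2, T2, C2> labels,
                             Region                  & region,
                             ArrayVector<Region>     & childRegions,
                             Random                  & randint)
           {
               // never split: every stack entry becomes a terminal node
               return SB::makeTerminalNode(features, labels, region, randint);
           }
       };
*/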
/** Functor to sort the indices of a feature matrix by a certain dimension.
**/
template<class DataMatrix>
class SortSamplesByDimensions
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;
    double thresVal_;
  public:

    SortSamplesByDimensions(DataMatrix const & data,
                            MultiArrayIndex sortColumn,
                            double thresVal = 0.0)
    : data_(data),
      sortColumn_(sortColumn),
      thresVal_(thresVal)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }
    void setThreshold(double value)
    {
        thresVal_ = value;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }

    bool operator()(MultiArrayIndex l) const
    {
        return data_(l, sortColumn_) < thresVal_;
    }
};

template<class DataMatrix>
class DimensionNotEqual
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;

  public:

    DimensionNotEqual(DataMatrix const & data,
                      MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) != data_(r, sortColumn_);
    }
};

template<class DataMatrix>
class SortSamplesByHyperplane
{
    DataMatrix const & data_;
    Node<i_HyperplaneNode> const & node_;

  public:

    SortSamplesByHyperplane(DataMatrix const & data,
                            Node<i_HyperplaneNode> const & node)
    : data_(data),
      node_(node)
    {}

    /** calculate the distance of a sample point to a hyperplane
     */
    double operator[](MultiArrayIndex l) const
    {
        double result_l = -1 * node_.intercept();
        for(int ii = 0; ii < node_.columns_size(); ++ii)
        {
            result_l += rowVector(data_, l)[node_.columns_begin()[ii]]
                        * node_.weights()[ii];
        }
        return result_l;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return (*this)[l] < (*this)[r];
    }

};

/** Builds a class histogram from indices into a label array.
 *  usage:
 *      MultiArrayView<2, T2, C2> labels = makeSomeLabels();
 *      ArrayVector<int> hist(numberOfLabels(labels), 0);
 *      RandomForestClassCounter<MultiArrayView<2, T2, C2>, ArrayVector<int> >
 *          counter(labels, hist);
 *
 *      Container<int> indices = getSomeIndices();
 *      std::for_each(indices.begin(), indices.end(), counter);
 */
template <class DataSource, class CountArray>
class RandomForestClassCounter
{
    DataSource const & labels_;
    CountArray       & counts_;

  public:

    RandomForestClassCounter(DataSource const & labels,
                             CountArray       & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    void reset()
    {
        counts_.init(0);
    }

    void operator()(MultiArrayIndex l) const
    {
        counts_[labels_[l]] += 1;
    }
};


/** Functor to calculate the best possible split based on the Gini index,
    given labels and features along a given axis.
*/

namespace detail
{
    template<int N>
    class ConstArr
    {
      public:
        double operator[](size_t) const
        {
            return (double)N;
        }
    };
}


/** Functor to calculate the entropy-based impurity.
*/
class EntropyCriterion
{
  public:
    /** calculate the weighted entropy impurity from a class histogram
     *  and class weights
     */
    template<class Array, class Array2>
    double operator()(Array const & hist,
                      Array2 const & weights,
                      double total = 1.0) const
    {
        return impurity(hist, weights, total);
    }

    /** calculate the entropy impurity from a class histogram
     */
    template<class Array>
    double operator()(Array const & hist, double total = 1.0) const
    {
        return impurity(hist, total);
    }

    /** static version of operator()(hist, total)
     */
    template<class Array>
    static double impurity(Array const & hist, double total)
    {
        return impurity(hist, detail::ConstArr<1>(), total);
    }

    /** static version of operator()(hist, weights, total)
     */
    template<class Array, class Array2>
    static double impurity(Array const & hist,
                           Array2 const & weights,
                           double total)
    {
        int class_count = hist.size();
        double entropy = 0.0;
        if(class_count == 2)
        {
            double p0 = (hist[0] / total);
            double p1 = (hist[1] / total);
            entropy = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
        }
        else
        {
            for(int ii = 0; ii < class_count; ++ii)
            {
                double w   = weights[ii];
                double pii = hist[ii] / total;
                entropy   -= w * (pii * std::log(pii));
            }
        }
        entropy = total * entropy;
        return entropy;
    }
};

/** Functor to calculate the Gini impurity.
*/
class GiniCriterion
{
  public:
    /** calculate the weighted Gini impurity from a class histogram
     *  and class weights
     */
    template<class Array, class Array2>
    double operator()(Array const & hist,
                      Array2 const & weights,
                      double total = 1.0) const
    {
        return impurity(hist, weights, total);
    }

    /** calculate the Gini impurity from a class histogram
     */
    template<class Array>
    double operator()(Array const & hist, double total = 1.0) const
    {
        return impurity(hist, total);
    }

    /** static version of operator()(hist, total)
     */
    template<class Array>
    static double impurity(Array const & hist, double total)
    {
        return impurity(hist, detail::ConstArr<1>(), total);
    }

    /** static version of operator()(hist, weights, total)
     */
    template<class Array, class Array2>
    static double impurity(Array const & hist,
                           Array2 const & weights,
                           double total)
    {
        int class_count = hist.size();
        double gini = 0.0;
        if(class_count == 2)
        {
            double w = weights[0] * weights[1];
            gini = w * (hist[0] * hist[1] / total);
        }
        else
        {
            for(int ii = 0; ii < class_count; ++ii)
            {
                double w = weights[ii];
                gini += w * (hist[ii] * (1.0 - w * hist[ii] / total));
            }
        }
        return gini;
    }
};
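/* Worked example (added to this listing; not part of the original header):
   both criteria can be evaluated directly on a class histogram. For a
   two-class histogram {30, 10} with unit class weights:

       ArrayVector<double> hist(2);
       hist[0] = 30.0;
       hist[1] = 10.0;
       double total   = 30.0 + 10.0;
       double gini    = GiniCriterion::impurity(hist, total);
       // binary case: hist[0]*hist[1]/total = 30*10/40 = 7.5
       double entropy = EntropyCriterion::impurity(hist, total);
       // total * (-0.75*log(0.75) - 0.25*log(0.25)) ~= 22.5 (natural log)
*/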
/** Maintains a running class histogram over a set of samples and evaluates
    the given impurity functor on it. increment()/decrement() move samples
    (given by index ranges) into or out of the histogram and return the
    updated impurity.
*/
template <class DataSource, class Impurity = GiniCriterion>
class ImpurityLoss
{
    DataSource  const &       labels_;
    ArrayVector<double>       counts_;
    ArrayVector<double> const class_weights_;
    double                    total_counts_;
    Impurity                  impurity_;

  public:

    template<class T>
    ImpurityLoss(DataSource const & labels,
                 ProblemSpec<T> const & ext_)
    : labels_(labels),
      counts_(ext_.class_count_, 0.0),
      class_weights_(ext_.class_weights_),
      total_counts_(0.0)
    {}

    void reset()
    {
        counts_.init(0);
        total_counts_ = 0.0;
    }

    template<class Counts>
    double increment_histogram(Counts const & counts)
    {
        std::transform(counts.begin(), counts.end(),
                       counts_.begin(), counts_.begin(),
                       std::plus<double>());
        total_counts_ = std::accumulate(counts_.begin(),
                                        counts_.end(),
                                        0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Counts>
    double decrement_histogram(Counts const & counts)
    {
        std::transform(counts_.begin(), counts_.end(),
                       counts.begin(), counts_.begin(),
                       std::minus<double>());
        total_counts_ = std::accumulate(counts_.begin(),
                                        counts_.end(),
                                        0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Iter>
    double increment(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            counts_[labels_(*iter, 0)] += 1.0;
            total_counts_ += 1.0;
        }
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Iter>
    double decrement(Iter const & begin, Iter const & end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            counts_[labels_(*iter, 0)] -= 1.0;
            total_counts_ -= 1.0;
        }
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Iter, class Resp_t>
    double init(Iter begin, Iter end, Resp_t resp)
    {
        reset();
        std::copy(resp.begin(), resp.end(), counts_.begin());
        total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    ArrayVector<double> const & response()
    {
        return counts_;
    }
};
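/* Note (added to this listing; not part of the original header): the loss of
   a candidate split is the sum of the impurities of the two children. This is
   the quantity that BestGiniOfColumn minimises below by maintaining one loss
   accumulator per side and shifting samples from the right to the left one.
   For a region with class histogram {8, 4} split into a left part {6, 1} and
   a right part {2, 3}:

       ArrayVector<double> left(2), right(2);
       left[0]  = 6.0;  left[1]  = 1.0;
       right[0] = 2.0;  right[1] = 3.0;
       double split_loss = GiniCriterion::impurity(left,  7.0)
                         + GiniCriterion::impurity(right, 5.0);
       // 6*1/7 + 2*3/5 = 0.857... + 1.2 = 2.057...
*/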
/** Maintains the running mean and variance of the responses of a set of
    samples; used as the loss accumulator for regression forests.
*/
template <class DataSource>
class RegressionForestCounter
{
    typedef MultiArrayShape<2>::type Shp;
    DataSource const &  labels_;
    ArrayVector<double> mean_;
    ArrayVector<double> variance_;
    ArrayVector<double> tmp_;
    size_t              count_;

  public:

    template<class T>
    RegressionForestCounter(DataSource const & labels,
                            ProblemSpec<T> const & ext_)
    : labels_(labels),
      mean_(ext_.response_size, 0.0),
      variance_(ext_.response_size, 0.0),
      tmp_(ext_.response_size),
      count_(0)
    {}

    // West's algorithm for incremental variance calculation
    template<class Iter>
    double increment(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            ++count_;
            for(int ii = 0; ii < mean_.size(); ++ii)
                tmp_[ii] = labels_(*iter, ii) - mean_[ii];
            double f  = 1.0 / count_,
                   f1 = 1.0 - f;
            for(int ii = 0; ii < mean_.size(); ++ii)
                mean_[ii] += f * tmp_[ii];
            for(int ii = 0; ii < mean_.size(); ++ii)
                variance_[ii] += f1 * sq(tmp_[ii]);
        }
        return std::accumulate(variance_.begin(),
                               variance_.end(),
                               0.0,
                               std::plus<double>())
               / (count_ - 1);
    }

    template<class Iter>
    double decrement(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            --count_;
            for(int ii = 0; ii < mean_.size(); ++ii)
                tmp_[ii] = labels_(*iter, ii) - mean_[ii];
            double f  = 1.0 / count_,
                   f1 = 1.0 + f;
            for(int ii = 0; ii < mean_.size(); ++ii)
                mean_[ii] -= f * tmp_[ii];
            for(int ii = 0; ii < mean_.size(); ++ii)
                variance_[ii] -= f1 * sq(tmp_[ii]);
        }
        return std::accumulate(variance_.begin(),
                               variance_.end(),
                               0.0,
                               std::plus<double>())
               / (count_ - 1);
    }

    template<class Iter, class Resp_t>
    double init(Iter begin, Iter end, Resp_t resp)
    {
        reset();
        return increment(begin, end);
    }


    ArrayVector<double> const & response()
    {
        return mean_;
    }

    void reset()
    {
        mean_.init(0.0);
        variance_.init(0.0);
        count_ = 0;
    }
};
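/* Worked illustration (added to this listing; not part of the original
   header) of the incremental update in RegressionForestCounter::increment()
   for a single response dimension: mean_ holds the running mean and
   variance_ the running sum of squared deviations, so the returned value is
   the sample variance of the responses seen so far.

       double mean = 0.0, ssd = 0.0;            // ssd: sum of squared deviations
       double samples[] = { 2.0, 4.0, 6.0 };
       for(int n = 1; n <= 3; ++n)
       {
           double delta = samples[n-1] - mean;
           double f     = 1.0 / n;
           mean += f * delta;                   // 2.0, 3.0, 4.0
           ssd  += (1.0 - f) * delta * delta;   // 0.0, 2.0, 8.0
       }
       // ssd / (3 - 1) == 4.0, the sample variance of {2, 4, 6}
*/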
template<class Tag, class Datatype>
struct LossTraits;

struct LSQLoss
{};

template<class Datatype>
struct LossTraits<GiniCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, GiniCriterion> type;
};

template<class Datatype>
struct LossTraits<EntropyCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, EntropyCriterion> type;
};

template<class Datatype>
struct LossTraits<LSQLoss, Datatype>
{
    typedef RegressionForestCounter<Datatype> type;
};
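/* Note (added to this listing; not part of the original header): LossTraits
   maps a criterion tag to the accumulator type that evaluates it. A
   user-defined criterion (the name MyCriterion is hypothetical) would be
   wired in with an analogous specialization:

       template<class Datatype>
       struct LossTraits<MyCriterion, Datatype>
       {
           typedef ImpurityLoss<Datatype, MyCriterion> type;
       };
*/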
/** Given a column, choose a split that minimizes some loss.
*/
template<class LineSearchLossTag>
class BestGiniOfColumn
{
  public:
    ArrayVector<double> class_weights_;
    ArrayVector<double> bestCurrentCounts[2];
    double              min_gini_;
    ptrdiff_t           min_index_;
    double              min_threshold_;
    ProblemSpec<>       ext_param_;

    BestGiniOfColumn()
    {}

    template<class T>
    BestGiniOfColumn(ProblemSpec<T> const & ext)
    : class_weights_(ext.class_weights_),
      ext_param_(ext)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_;
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    /** calculate the best gini split along a feature column
     * \param column   the feature data; entries of column g are accessed as column(sample_index, g)
     * \param g        index of the column to be considered
     * \param labels   the label vector
     * \param begin
     * \param end      (in and out) begin and end iterators to the indices of the
     *                 samples in the current region. The range begin - end is
     *                 sorted by the given column during function execution.
     * \param region_response
     *                 class histogram of the range.
     *
     * precondition:  begin, end form a valid range;
     *                region_response is a positive integer valued array with the
     *                class counts of the current range;
     *                labels.size() >= max(begin, end).
     * postcondition:
     *                begin, end sorted by the given column;
     *                min_gini_ contains the minimum loss found, or the loss of the
     *                unsplit region if no improving split was found;
     *                min_index_ contains the splitting index in the range,
     *                or invalid data if no split was found;
     *                bestCurrentCounts[0] and [1] contain the class histograms
     *                of the left and right regions.
     */
    template< class DataSourceF_t,
              class DataSource_t,
              class I_Iter,
              class Array>
    void operator()(DataSourceF_t const & column,
                    int g,
                    DataSource_t const & labels,
                    I_Iter & begin,
                    I_Iter & end,
                    Array const & region_response)
    {
        std::sort(begin, end,
                  SortSamplesByDimensions<DataSourceF_t>(column, g));
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);

        min_gini_      = right.init(begin, end, region_response);
        min_threshold_ = *begin;
        min_index_     = 0;
        DimensionNotEqual<DataSourceF_t> comp(column, g);

        I_Iter iter = begin;
        I_Iter next = std::adjacent_find(iter, end, comp);
        // linear sweep over all boundaries between distinct feature values
        while(next != end)
        {
            double loss = right.decrement(iter, next + 1)
                        + left.increment(iter, next + 1);
#ifdef CLASSIFIER_TEST
            if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
#else
            if(loss < min_gini_)
#endif
            {
                bestCurrentCounts[0] = left.response();
                bestCurrentCounts[1] = right.response();
#ifdef CLASSIFIER_TEST
                min_gini_ = loss < min_gini_ ? loss : min_gini_;
#else
                min_gini_ = loss;
#endif
                min_index_     = next - begin + 1;
                min_threshold_ = (double(column(*next, g)) + double(column(*(next+1), g))) / 2.0;
            }
            iter = next + 1;
            next = std::adjacent_find(iter, end, comp);
        }
    }

    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin,
                          Iter & end,
                          Array const & region_response) const
    {
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return region_loss.init(begin, end, region_response);
    }

};
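/* Note (added to this listing; not part of the original header): after
   sorting by the chosen column, the candidate thresholds examined by
   BestGiniOfColumn::operator() are the midpoints between adjacent distinct
   feature values. For sorted column values {1.0, 1.0, 3.0, 7.0} the
   candidates are 2.0 and 5.0; at each boundary the samples that cross it are
   moved from the right accumulator to the left one, so one column costs an
   O(n log n) sort plus an O(n) sweep.
*/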
/** Chooses mtry columns, applies the ColumnDecisionFunctor to each of them,
    and then chooses the best of these columns.
*/
template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
class ThresholdSplit: public SplitBase<Tag>
{
  public:

    typedef SplitBase<Tag> SB;

    ArrayVector<Int32>     splitColumns;
    ColumnDecisionFunctor  bgfunc;

    double                 region_gini_;
    ArrayVector<double>    min_gini_;
    ArrayVector<ptrdiff_t> min_indices_;
    ArrayVector<double>    min_thresholds_;

    int                    bestSplitIndex;

    double minGini() const
    {
        return min_gini_[bestSplitIndex];
    }
    int bestSplitColumn() const
    {
        return splitColumns[bestSplitIndex];
    }
    double bestSplitThreshold() const
    {
        return min_thresholds_[bestSplitIndex];
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & in)
    {
        SB::set_external_parameters(in);
        bgfunc.set_external_parameters(SB::ext_param_);
        int featureCount_ = SB::ext_param_.column_count_;
        splitColumns.resize(featureCount_);
        for(int k = 0; k < featureCount_; ++k)
            splitColumns[k] = k;
        min_gini_.resize(featureCount_);
        min_indices_.resize(featureCount_);
        min_thresholds_.resize(featureCount_);
    }


    template<class T, class C, class T2, class C2, class Region, class Random>
    int findBestSplit(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2> labels,
                      Region & region,
                      ArrayVector<Region> & childRegions,
                      Random & randint)
    {
        typedef typename Region::IndexIterator IndexIterator;
        if(region.size() == 0)
        {
            std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
                         "continuing learning process....";
        }

        // calculate the class counts if they haven't been calculated yet
        if(std::accumulate(region.classCounts().begin(),
                           region.classCounts().end(), 0) != region.size())
        {
            RandomForestClassCounter<MultiArrayView<2, T2, C2>,
                                     ArrayVector<double> >
                counter(labels, region.classCounts());
            std::for_each(region.begin(), region.end(), counter);
            region.classCountsIsValid = true;
        }

        // is the region pure already?
        region_gini_ = bgfunc.loss_of_region(labels,
                                             region.begin(),
                                             region.end(),
                                             region.classCounts());
        if(region_gini_ <= SB::ext_param_.precision_)
            return SB::makeTerminalNode(features, labels, region, randint);

        // select columns to be tried (partial Fisher-Yates shuffle)
        for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
            std::swap(splitColumns[ii],
                      splitColumns[ii + randint(features.shape(1) - ii)]);

        // find the best gini index
        bestSplitIndex = 0;
        double current_min_gini = region_gini_;
        int num2try = features.shape(1);
        for(int k = 0; k < num2try; ++k)
        {
            // this functor does all the work
            bgfunc(features,
                   splitColumns[k],
                   labels,
                   region.begin(), region.end(),
                   region.classCounts());
            min_gini_[k]       = bgfunc.min_gini_;
            min_indices_[k]    = bgfunc.min_index_;
            min_thresholds_[k] = bgfunc.min_threshold_;
#ifdef CLASSIFIER_TEST
            if( bgfunc.min_gini_ < current_min_gini
                && !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
#else
            if(bgfunc.min_gini_ < current_min_gini)
#endif
            {
                current_min_gini = bgfunc.min_gini_;
                childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
                childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
                childRegions[0].classCountsIsValid = true;
                childRegions[1].classCountsIsValid = true;

                bestSplitIndex = k;
                // once at least one improving split has been found, the search
                // stops after the first actual_mtry_ (randomly chosen) columns;
                // otherwise all columns are tried
                num2try = SB::ext_param_.actual_mtry_;
            }
        }

        // did not find any suitable split
        if(closeAtTolerance(current_min_gini, region_gini_))
            return SB::makeTerminalNode(features, labels, region, randint);

        // create a node for output
        Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
        SB::node_ = node;
        node.threshold() = min_thresholds_[bestSplitIndex];
        node.column()    = splitColumns[bestSplitIndex];

        // partition the range according to the best dimension
        SortSamplesByDimensions<MultiArrayView<2, T, C> >
            sorter(features, node.column(), node.threshold());
        IndexIterator bestSplit =
            std::partition(region.begin(), region.end(), sorter);
        // save the ranges of the child stack entries
        childRegions[0].setRange(region.begin(), bestSplit);
        childRegions[0].rule = region.rule;
        childRegions[0].rule.push_back(std::make_pair(1, 1.0));
        childRegions[1].setRange(bestSplit, region.end());
        childRegions[1].rule = region.rule;
        childRegions[1].rule.push_back(std::make_pair(1, 1.0));

        return i_ThresholdNode;
    }
};

typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> >         GiniSplit;
typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >      EntropySplit;
typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
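/* Usage note (added to this listing; not part of the original header): these
   typedefs are the split functors that can be passed to RandomForest::learn()
   (the exact overload accepting a split functor depends on the VIGRA version;
   see random_forest.hxx). A new axis-parallel splitter is obtained by
   combining ThresholdSplit with any column functor, e.g.

       typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> > MyEntropySplit;

   which is equivalent to the EntropySplit typedef above.
*/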
namespace rf
{

/** This namespace contains additional split functors.
 *
 * The split functor classes are designed in a modular fashion because new split functors may
 * share a lot of code with existing ones.
 *
 * ThresholdSplit implements the functionality needed for any split functor that makes its
 * decision via one-dimensional, axis-parallel cuts. The template parameter defines how the split
 * along one dimension is chosen.
 *
 * The BestGiniOfColumn class chooses a split that minimizes one of the loss functions supplied
 * (GiniCriterion for classification and LSQLoss for regression). Median chooses the split in a
 * kD-tree fashion.
 *
 *
 * Currently defined typedefs:
 * \code
 * typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> >         GiniSplit;
 * typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
 * typedef ThresholdSplit<Median>                                   MedianSplit;
 * \endcode
 */
namespace split
{

/** This functor chooses the median value of a column.
*/
class Median
{
  public:

    typedef GiniCriterion LineSearchLossTag;
    ArrayVector<double>   class_weights_;
    ArrayVector<double>   bestCurrentCounts[2];
    double                min_gini_;
    ptrdiff_t             min_index_;
    double                min_threshold_;
    ProblemSpec<>         ext_param_;

    Median()
    {}

    template<class T>
    Median(ProblemSpec<T> const & ext)
    : class_weights_(ext.class_weights_),
      ext_param_(ext)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_;
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template< class DataSourceF_t,
              class DataSource_t,
              class I_Iter,
              class Array>
    void operator()(DataSourceF_t const & column,
                    DataSource_t const & labels,
                    I_Iter & begin,
                    I_Iter & end,
                    Array const & region_response)
    {
        std::sort(begin, end,
                  SortSamplesByDimensions<DataSourceF_t>(column, 0));
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);
        right.init(begin, end, region_response);

        min_gini_      = NumericTraits<double>::max();
        min_index_     = floor(double(end - begin) / 2.0);
        min_threshold_ = column[*(begin + min_index_)];
        SortSamplesByDimensions<DataSourceF_t>
            sorter(column, 0, min_threshold_);
        I_Iter part = std::partition(begin, end, sorter);
        DimensionNotEqual<DataSourceF_t> comp(column, 0);
        if(part == begin)
        {
            part = std::adjacent_find(part, end, comp) + 1;
        }
        if(part >= end)
        {
            return;
        }
        else
        {
            min_threshold_ = column[*part];
        }
        min_gini_ = right.decrement(begin, part)
                  + left.increment(begin, part);

        bestCurrentCounts[0] = left.response();
        bestCurrentCounts[1] = right.response();

        min_index_ = part - begin;
    }

    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin,
                          Iter & end,
                          Array const & region_response) const
    {
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return region_loss.init(begin, end, region_response);
    }

};

typedef ThresholdSplit<Median> MedianSplit;


/** This functor chooses a random value of a column.
*/
class RandomSplitOfColumn
{
  public:

    typedef GiniCriterion LineSearchLossTag;
    ArrayVector<double>   class_weights_;
    ArrayVector<double>   bestCurrentCounts[2];
    double                min_gini_;
    ptrdiff_t             min_index_;
    double                min_threshold_;
    ProblemSpec<>         ext_param_;
    typedef RandomMT19937 Random_t;
    Random_t              random;

    RandomSplitOfColumn()
    {}

    template<class T>
    RandomSplitOfColumn(ProblemSpec<T> const & ext)
    : class_weights_(ext.class_weights_),
      ext_param_(ext),
      random(RandomSeed)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
    : class_weights_(ext.class_weights_),
      ext_param_(ext),
      random(random_)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_;
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template< class DataSourceF_t,
              class DataSource_t,
              class I_Iter,
              class Array>
    void operator()(DataSourceF_t const & column,
                    DataSource_t const & labels,
                    I_Iter & begin,
                    I_Iter & end,
                    Array const & region_response)
    {
        std::sort(begin, end,
                  SortSamplesByDimensions<DataSourceF_t>(column, 0));
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);
        right.init(begin, end, region_response);

        min_gini_ = NumericTraits<double>::max();

        min_index_     = random.uniformInt(end - begin);
        min_threshold_ = column[*(begin + min_index_)];
        SortSamplesByDimensions<DataSourceF_t>
            sorter(column, 0, min_threshold_);
        I_Iter part = std::partition(begin, end, sorter);
        DimensionNotEqual<DataSourceF_t> comp(column, 0);
        if(part == begin)
        {
            part = std::adjacent_find(part, end, comp) + 1;
        }
        if(part >= end)
        {
            return;
        }
        else
        {
            min_threshold_ = column[*part];
        }
        min_gini_ = right.decrement(begin, part)
                  + left.increment(begin, part);

        bestCurrentCounts[0] = left.response();
        bestCurrentCounts[1] = right.response();

        min_index_ = part - begin;
    }

    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin,
                          Iter & end,
                          Array const & region_response) const
    {
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return region_loss.init(begin, end, region_response);
    }

};

typedef ThresholdSplit<RandomSplitOfColumn> RandomSplit;

} // namespace split
} // namespace rf

} // namespace vigra
#endif // VIGRA_RANDOM_FOREST_SPLIT_HXX