/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,   */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <list>
#include <numeric>
#include "mathutil.hxx"
#include "array_vector.hxx"
#include "sized_int.hxx"
#include "matrix.hxx"
#include "random.hxx"
#include "functorexpression.hxx"
#include "random_forest/rf_common.hxx"
#include "random_forest/rf_nodeproxy.hxx"
#include "random_forest/rf_split.hxx"
#include "random_forest/rf_decisionTree.hxx"
#include "random_forest/rf_visitors.hxx"
#include "random_forest/rf_region.hxx"
#include "random_forest/rf_sampling.hxx"
#include "random_forest/rf_preprocessing.hxx"
#include "random_forest/rf_online_prediction_set.hxx"
#include "random_forest/rf_earlystopping.hxx"

namespace vigra
{

/** \addtogroup MachineLearning Machine Learning

    This module provides classification algorithms that map
    features to labels or label probabilities.
**/
//@{

namespace detail
{

/* todo - remove and make the labels parameter in the sampling options
 * const
 */
class staticMultiArrayViewHelper
{
  public:
    static vigra::MultiArrayView<2, Int32> array;

    friend SamplingOptions
    createSamplingOptions(vigra::RandomForestOptions & RF_opt,
                          vigra::MultiArrayView<2, int> & labels);
};


/* \brief sampling option factory function
 */
SamplingOptions make_sampler_opt( RF_Traits::Options_t & RF_opt,
                                  MultiArrayView<2, Int32> & labels
                                          = staticMultiArrayViewHelper::array)
{
    SamplingOptions return_opt;
    return_opt.sample_with_replacement = RF_opt.sample_with_replacement_;
    if(labels.data() != 0)
    {
        if(RF_opt.stratification_method_ == RF_EQUAL)
            return_opt
                .sampleClassesIndividually(
                    ArrayVectorView<int>(labels.size(),
                                         labels.data()));
        else if(RF_opt.stratification_method_ == RF_PROPORTIONAL)
            return_opt
                .sampleStratified(
                    ArrayVectorView<int>(labels.size(),
                                         labels.data()));
    }
    return return_opt;
}
} //namespace detail

/** Random Forest class
 *
 * \tparam PreprocessorTag (default: ClassificationTag) class used to
 *         preprocess the input while learning and predicting. Currently
 *         available: ClassificationTag and RegressionTag. It is recommended
 *         to use Splitfunctor::Preprocessor_t when using custom split
 *         functors, as they may need the data to be in a different format.
 * \sa Preprocessor, How to make a Split Functor
 */
template <class LabelType, class PreprocessorTag>
class RandomForest
{

  public:
    // public typedefs
    typedef RF_Traits::Options_t         Options_t;
    typedef RF_Traits::DecisionTree_t    DecisionTree_t;
    typedef ProblemSpec<LabelType>       ProblemSpec_t;
    typedef RF_Traits::Default_Split_t   Default_Split_t;
    typedef RF_Traits::Default_Stop_t    Default_Stop_t;
    typedef RF_Traits::Default_Visitor_t Default_Visitor_t;
    typedef LabelType                    LabelT;

  protected:

    /** optimisation for predictLabels
     */
    mutable MultiArray<2, double> garbage_prediction_;

  public:

    // problem independent data
    Options_t options_;
    // problem dependent data members - only set if
    // a copy constructor, some sort of import
    // function or the learn function is called
    ArrayVector<DecisionTree_t> trees_;
    ProblemSpec_t               ext_param_;
    mutable ArrayVector<int>    tree_indices_;
    OnlineLearnVisitor          online_visitor_;


    void reset()
    {
        ext_param_.clear();
        trees_.clear();
    }

  public:

    /** \name Constructors
     * Note: no copy constructor is specified, as no pointers are
     * manipulated in this class.
     */
    /*\{*/
    /**\brief default constructor
     *
     * \param options   general options for the Random Forest. Must be of
     *                  type Options_t
     * \param ext_param problem specific values that can be supplied
     *                  additionally (class weights, labels etc.)
     * \sa ProblemSpec_t
     *
     * Simple usage for classification (regression is not yet supported):
     * \code
     * typedef xxx feature_t; // replace xxx with whichever type
     * typedef yyy label_t;   // same thing here.
     * MultiArrayView<2, feature_t> f = get_some_features();
     * MultiArrayView<2, label_t>   l = get_some_labels();
     * RandomForest<> rf;
     * double oob_error = rf.learn(f, l);
     *
     * MultiArrayView<2, feature_t> pf = get_some_unknown_features();
     * MultiArrayView<2, label_t> prediction
     *                              = allocate_space_for_response();
     * MultiArrayView<2, double> prob = allocate_space_for_probability();
     *
     * rf.predictLabels(pf, prediction);
     * rf.predictProbabilities(pf, prob);
     *
     * \endcode
     *
     * - Default response/label type is double.
     */
    RandomForest(Options_t const & options = Options_t(),
                 ProblemSpec_t const & ext_param = ProblemSpec_t())
    :
        options_(options),
        ext_param_(ext_param),
        tree_indices_(options.tree_count_, 0)
    {
        for(int ii = 0; ii < int(tree_indices_.size()); ++ii)
            tree_indices_[ii] = ii;
    }

    /**\brief Create RF from external source
     * \param treeCount Number of trees to add.
     * \param trees     Iterator to a container where the topology_ data
     *                  of the trees is stored.
     * \param weights   Iterator to a container where the parameters_ data
     *                  of the trees is stored.
     * \param problem_spec
     *                  Extrinsic parameters that specify the problem, e.g.
     *                  class count, feature count etc.
     * \param options   (optional) options used to train the original
     *                  Random Forest. This parameter is not used anywhere
     *                  during prediction and is thus optional.
     *
     * TODO:
     * Note: this constructor may be replaced by a constructor using
     * NodeProxy iterators to encapsulate the underlying data type.
     */
    template<class TreeIterator, class WeightIterator>
    RandomForest(int treeCount,
                 TreeIterator trees,
                 WeightIterator weights,
                 ProblemSpec_t const & problem_spec,
                 Options_t const & options = Options_t())
    :
        options_(options),
        trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec)
    {
        for(int k = 0; k < treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].topology_   = *trees;
            trees_[k].parameters_ = *weights;
        }
    }

    /*\}*/
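
    // Example: reconstructing a forest from externally stored trees. This
    // is only a sketch, not code from this library: it assumes that each
    // tree's topology_ is an ArrayVector<Int32> and its parameters_ an
    // ArrayVector<double> (cf. rf_decisionTree.hxx); any iterator pair
    // with assignment-compatible value types would work just as well.
    //
    //   ArrayVector<ArrayVector<Int32> >  topologies; // e.g. loaded from disk
    //   ArrayVector<ArrayVector<double> > weights;    // ditto
    //   ProblemSpec<double> spec;                     // class/feature counts etc.
    //   RandomForest<double> rf(int(topologies.size()),
    //                           topologies.begin(),
    //                           weights.begin(),
    //                           spec);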

    /** \name Data Access
     * data access interface - usage of member objects is deprecated
     * (I like the word deprecated)
     */

    /*\{*/


    /**\brief return external parameters for viewing
     * \return ProblemSpec_t
     */
    ProblemSpec_t const & ext_param() const
    {
        vigra_precondition(ext_param_.used() == true,
            "RandomForest::ext_param(): "
            "Random forest has not been trained yet.");
        return ext_param_;
    }

    /**\brief set external parameters
     *
     * \param in external parameters to be set
     *
     * Set external parameters explicitly.
     * If the Random Forest has not been trained, the preprocessor will
     * either ignore filling values set this way or will throw an exception
     * if values specified manually do not match the value calculated
     * during the preparation step.
     * \sa Options_t::presupplied_ext_param member for further details.
     */
    void set_ext_param(ProblemSpec_t const & in)
    {
        vigra_precondition(ext_param_.used() == false,
            "RandomForest::set_ext_param(): "
            "Random forest has been trained! Call reset() "
            "before specifying new extrinsic parameters.");
        ext_param_ = in;
    }

    /**\brief access random forest options
     *
     * \return random forest options
     */
    Options_t & set_options()
    {
        return options_;
    }


    /**\brief access const random forest options
     *
     * \return const Options_t
     */
    Options_t const & options() const
    {
        return options_;
    }

    /**\brief access const trees
     */
    DecisionTree_t const & tree(int index) const
    {
        return trees_[index];
    }

    /**\brief access trees
     */
    DecisionTree_t & tree(int index)
    {
        return trees_[index];
    }

    /*\}*/

    int column_count() const
    {
        return ext_param_.column_count_;
    }

    int class_count() const
    {
        return ext_param_.class_count_;
    }

    int tree_count() const
    {
        return options_.tree_count_;
    }

    /**\name Learning
     * The following functions differ in the degree of customization
     * allowed.
     */
    /*\{*/
    /**\brief learn on data with custom config and random number generator
     *
     * \param features a N x M matrix containing N samples with M
     *                 features
     * \param response a N x D matrix containing the corresponding
     *                 response. Current split functors assume D to
     *                 be 1 and ignore any additional columns.
     *                 This is not enforced to allow future support
     *                 for uncertain labels, label independent strata etc.
     *                 The Preprocessor specified during construction
     *                 should be able to handle the features and labels.
     * \sa SplitFunctor, Preprocessing
     *
     * \param visitor  visitor which is to be applied after each split,
     *                 each tree, and at the end. Use rf_default() for the
     *                 default value.
     * \sa visitor
     * \param split    split functor used to calculate each split.
     *                 Use rf_default() for the default value.
     * \param stop     predicate used to decide when to stop splitting.
     *                 Use rf_default() for the default value.
     * \param random   random number generator to be used. Use
     *                 rf_default() for the default value.
     * \return oob_error.
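     *
     * A fully customised call might look as follows - just a sketch,
     * assuming the default functor types named in this file (GiniSplit is
     * default-constructible, EarlyStoppingStd is constructed from the
     * option object, as in the implementation below):
     * \code
     * RandomForest<> rf;
     * GiniSplit split;
     * EarlyStoppingStd stop(rf.options());
     * RandomNumberGenerator<> rnd(RandomSeed);
     * double oob = rf.learn(features, labels, rf_default(), split, stop, rnd);
     * \endcode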
     *
     * \sa OOB_Visitor, VariableImportanceVisitor
     *
     */
    template <class U, class C1,
              class U2, class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t,
              class Random_t>
    double learn( MultiArrayView<2, U, C1> const & features,
                  MultiArrayView<2, U2, C2> const & response,
                  Visitor_t visitor,
                  Split_t split,
                  Stop_t stop,
                  Random_t const & random);

    template <class U, class C1,
              class U2, class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t>
    double learn( MultiArrayView<2, U, C1> const & features,
                  MultiArrayView<2, U2, C2> const & response,
                  Visitor_t visitor,
                  Split_t split,
                  Stop_t stop)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        return learn(features, response,
                     visitor, split, stop,
                     rnd);
    }


    template<class U, class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    double onlineLearn(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, U2, C2> const & response,
                       int new_start_index,
                       Visitor_t visitor_,
                       Split_t split_,
                       Stop_t stop_,
                       Random_t & random,
                       bool adjust_thresholds = false);

    template <class U, class C1, class U2, class C2>
    double onlineLearn( MultiArrayView<2, U, C1> const & features,
                        MultiArrayView<2, U2, C2> const & labels,
                        int new_start_index,
                        bool adjust_thresholds = false)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        return onlineLearn(features,
                           labels,
                           new_start_index,
                           rf_default(),
                           rf_default(),
                           rf_default(),
                           rnd,
                           adjust_thresholds);
    }

    template<class U, class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & response,
                     int treeId,
                     Visitor_t visitor_,
                     Split_t split_,
                     Stop_t stop_,
                     Random_t & random);

    template<class U, class C1, class U2, class C2>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int treeId)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        reLearnTree(features,
                    labels,
                    treeId,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd);
    }
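
    // Example: incremental training with onlineLearn(). A sketch only - it
    // assumes that new_start_index is the row at which the newly appended
    // samples begin (this matches the PoissonSampler set up in the
    // implementation below), and that the option setter for
    // prepare_online_learning_ follows the usual naming scheme:
    //
    //   RandomForest<> rf(RandomForestOptions().prepare_online_learning(true));
    //   rf.learn(features, labels);          // initial batch: rows 0..N-1
    //   // ... append M new rows to features and labels ...
    //   rf.onlineLearn(features, labels, N); // incorporate rows N..N+M-1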



    template <class U, class C1, class U2, class C2, class Visitor_t>
    double learn( MultiArrayView<2, U, C1> const & features,
                  MultiArrayView<2, U2, C2> const & labels,
                  Visitor_t visitor)
    {
        return learn(features,
                     labels,
                     visitor,
                     rf_default(),
                     rf_default());
    }

    template <class U, class C1, class U2, class C2,
              class Visitor_t, class Split_t>
    double learn( MultiArrayView<2, U, C1> const & features,
                  MultiArrayView<2, U2, C2> const & labels,
                  Visitor_t visitor,
                  Split_t split)
    {
        return learn(features,
                     labels,
                     visitor,
                     split,
                     rf_default());
    }

    /**\brief learn on data with default configuration
     *
     * \param features a N x M matrix containing N samples with M
     *                 features
     * \param labels   a N x D matrix containing the corresponding
     *                 N labels. Current split functors assume D to
     *                 be 1 and ignore any additional columns.
     *                 This is not enforced to allow future support
     *                 for uncertain labels.
     * \return out of bag error estimate.
     *
     * Learning is done with:
     *
     * \sa GiniSplit, EarlyStoppingStd, OOB_Visitor
     *
     * - a randomly seeded random number generator
     * - the default gini split functor as described by Breiman
     * - the standard early stopping criterion
     * - the oob visitor, whose value is returned.
     */
    template <class U, class C1, class U2, class C2>
    double learn( MultiArrayView<2, U, C1> const & features,
                  MultiArrayView<2, U2, C2> const & labels)
    {
        return learn(features,
                     labels,
                     rf_default(),
                     rf_default(),
                     rf_default());
    }
    /*\}*/



    /**\name Prediction
     */
    /*\{*/
    /** \brief predict a label given a feature.
     *
     * \param features a 1 x featureCount matrix containing the
     *                 data point to be predicted (this only works in the
     *                 classification setting)
     * \param stop     early stopping criterion
     * \return double value representing the class. You can use the
     *         predictLabels() function together with the
     *         rf.external_parameter().class_type_ attribute
     *         to get back the same type used during learning.
     */
    template <class U, class C, class Stop>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                           Stop & stop) const;

    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features)
    {
        return predictLabel(features, rf_default());
    }

    /** \brief predict a label with features and class priors
     *
     * \param features same as above.
     * \param prior    iterator to prior weighting of classes
     * \return same as above.
     */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                           ArrayVectorView<double> prior) const;

    /** \brief predict multiple labels with given features
     *
     * \param features a n x featureCount matrix containing the
     *                 data points to be predicted (this only works in the
     *                 classification setting)
     * \param labels   a n x 1 matrix passed by reference to store the
     *                 output.
     */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k = 0; k < features.shape(0); ++k)
            labels(k, 0) = detail::RequiresExplicitCast<T>::cast(
                               predictLabel(rowVector(features, k), rf_default()));
    }

    template <class U, class C1, class T, class C2, class Stop>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels,
                       Stop & stop) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k = 0; k < features.shape(0); ++k)
            labels(k, 0) = detail::RequiresExplicitCast<T>::cast(
                               predictLabel(rowVector(features, k), stop));
    }
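
    // Example: batch prediction with early stopping. A sketch - the functor
    // name StopAfterVoteCount is taken from rf_earlystopping.hxx for
    // illustration, and its constructor parameters are an assumption (check
    // that header); any functor implementing after_prediction() works:
    //
    //   StopAfterVoteCount stop(/* parameters: see rf_earlystopping.hxx */);
    //   MultiArray<2, double> labels(MultiArrayShape<2>::type(test_features.shape(0), 1));
    //   rf.predictLabels(test_features, labels, stop);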
    /** \brief predict the class probabilities for multiple labels
     *
     * \param features same as above
     * \param prob     a n x class_count_ matrix passed by reference to
     *                 save the class probabilities
     * \param stop     early stopping criterion
     * \sa EarlyStopping
     */
    template <class U, class C1, class T, class C2, class Stop>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob,
                              Stop & stop) const;

    template <class T1, class T2, class C>
    void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                              MultiArrayView<2, T2, C> & prob);

    /** \brief predict the class probabilities for multiple labels
     *
     * \param features same as above
     * \param prob     a n x class_count_ matrix passed by reference to
     *                 save the class probabilities
     */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const
    {
        predictProbabilities(features, prob, rf_default());
    }


    /*\}*/

};


template <class LabelType, class PreprocessorTag>
template<class U, class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
double RandomForest<LabelType, PreprocessorTag>::
    onlineLearn(MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2, C2> const & response,
                int new_start_index,
                Visitor_t visitor_,
                Split_t split_,
                Stop_t stop_,
                Random_t & random,
                bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds = adjust_thresholds;

    using namespace rf;
    // typedefs
    typedef typename Split_t::StackEntry_t StackEntry_t;
    typedef Processor<PreprocessorTag, LabelType, U, C1, U2, C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t> RandFunctor_t;

    // Default values and initialization:
    // Value_Chooser chooses its second argument as value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // worry about it, just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    StopVisiting stopvisiting;
    OOB_Visitor oob;
    typedef VisitorNode<OnlineLearnVisitor,
                        typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        inter(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    VisitorNode<OOB_Visitor, IntermedVis>
        visitor(oob, inter);
    #undef RF_CHOOSER

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    ext_param_.class_count_ = 0;
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Make an stl compatible random functor.
    RandFunctor_t randint(random);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    // Create Poisson samples.
    PoissonSampler<RandomTT800> poisson_sampler(1.0,
                                                vigra::Int32(new_start_index),
                                                vigra::Int32(ext_param().row_count_));

    // TODO: visitors for online learning
    // visitor.visit_at_beginning(*this, preprocessor);

    // THE MAIN EFFING RF LOOP - YEAY DUDE!
    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        online_visitor_.tree_id = ii;
        poisson_sampler.sample();
        std::map<int, int> leaf_parents;
        leaf_parents.clear();
        // Get all the leaf nodes for that sample.
        for(int s = 0; s < poisson_sampler.numOfSamples(); ++s)
        {
            int sample = poisson_sampler[s];
            online_visitor_.current_label = preprocessor.response()(sample, 0);
            online_visitor_.last_node_id = StackEntry_t::DecisionTreeNoParent;
            int leaf = trees_[ii].getToLeaf(rowVector(features, sample),
                                            online_visitor_);


            // Add the sample to the index list of that leaf.
            online_visitor_.add_to_index_list(ii, leaf, sample);
            // TODO: class count?
            // Store the parent (only for leaves that are not yet pure).
            if(Node<e_ConstProbNode>(trees_[ii].topology_,
                                     trees_[ii].parameters_,
                                     leaf).prob_begin()[preprocessor.response()(sample, 0)] != 1.0)
            {
                leaf_parents[leaf] = online_visitor_.last_node_id;
            }
        }


        std::map<int, int>::iterator leaf_iterator;
        for(leaf_iterator = leaf_parents.begin();
            leaf_iterator != leaf_parents.end();
            ++leaf_iterator)
        {
            int leaf   = leaf_iterator->first;
            int parent = leaf_iterator->second;
            int lin_index =
                online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indeces;
            indeces.clear();
            indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indeces.begin(),
                                     indeces.end(),
                                     ext_param_.class_count_);


            if(parent != -1)
            {
                if(NodeBase(trees_[ii].topology_,
                            trees_[ii].parameters_, parent).child(0) == leaf)
                {
                    stack_entry.leftParent = parent;
                }
                else
                {
                    vigra_assert(NodeBase(trees_[ii].topology_,
                                          trees_[ii].parameters_, parent).child(1) == leaf,
                                 "last_node_id seems to be wrong");
                    stack_entry.rightParent = parent;
                }
            }
            //trees_[ii].continueLearn(preprocessor.features(), preprocessor.response(),
            //                         stack_entry, split, stop, visitor, randint, leaf);
            trees_[ii].continueLearn(preprocessor.features(),
                                     preprocessor.response(),
                                     stack_entry, split, stop, visitor, randint, -1);
            // Now, the last node has been moved onto the leaf position.
            online_visitor_.move_exterior_node(ii, trees_[ii].topology_.size(),
                                               ii, leaf);
            // Now it should be classified correctly!
        }

        /*visitor
            .visit_after_tree( *this,
                               preprocessor,
                               poisson_sampler,
                               stack_entry,
                               ii);*/
    }

    //visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();

    return visitor.return_val();
}

template<class LabelType, class PreprocessorTag>
template<class U, class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
    reLearnTree(MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2, C2> const & response,
                int treeId,
                Visitor_t visitor_,
                Split_t split_,
                Stop_t stop_,
                Random_t & random)
{
    using namespace rf;
    // We store this as a local variable, because there is no global interest ?!?
    typedef typename Split_t::StackEntry_t StackEntry_t;
    typedef UniformIntRandomFunctor<Random_t> RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this.
    ext_param_.class_count_ = 0;
    typedef Processor<PreprocessorTag, LabelType, U, C1, U2, C2> Preprocessor_t;

    // Default values and initialization:
    // Value_Chooser chooses its second argument as value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // worry about it, just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    StopVisiting stopvisiting;
    OOB_Visitor oob;
    typedef VisitorNode<OnlineLearnVisitor,
                        typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        inter(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    VisitorNode<OOB_Visitor, IntermedVis>
        visitor(oob, inter);
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,
        "reLearnTree: Relearning trees only makes sense if online learning "
        "is enabled.");
    online_visitor_.activate();

    // Make an stl compatible random functor.
    RandFunctor_t randint(random);

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    /**\todo replace this crappy class. It uses function pointers
     * and is making the code slower according to me.
     * Comment from Nathan: This is copied from Rahul, so me = Rahul.
     */
    Sampler<RandFunctor_t> sampler(ext_param().row_count_,
                                   ext_param().actual_msample_,
                                   detail::make_sampler_opt(options_,
                                                            preprocessor.strata()),
                                   randint);

    // Initialize the first region/node/stack entry.
    sampler
        .sample();

    StackEntry_t
        first_stack_entry( sampler.used_indices().begin(),
                           sampler.used_indices().end(),
                           ext_param_.class_count_);
    first_stack_entry
        .set_oob_range( sampler.unused_indices().begin(),
                        sampler.unused_indices().end());
    online_visitor_.reset_tree(treeId);
    online_visitor_.tree_id = treeId;
    trees_[treeId].reset();
    trees_[treeId]
        .learn( preprocessor.features(),
                preprocessor.response(),
                first_stack_entry,
                split,
                stop,
                visitor,
                randint);
    visitor
        .visit_after_tree( *this,
                           preprocessor,
                           sampler,
                           first_stack_entry,
                           treeId);

    online_visitor_.deactivate();
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1,
          class U2, class C2,
          class Split_t,
          class Stop_t,
          class Visitor_t,
          class Random_t>
double RandomForest<LabelType, PreprocessorTag>::
    learn( MultiArrayView<2, U, C1> const & features,
           MultiArrayView<2, U2, C2> const & response,
           Visitor_t visitor_,
           Split_t split_,
           Stop_t stop_,
           Random_t const & random)
{
    using namespace rf;
    //this->reset();
    // typedefs
    typedef typename Split_t::StackEntry_t StackEntry_t;
    typedef UniformIntRandomFunctor<Random_t> RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this.
    typedef Processor<PreprocessorTag, LabelType, U, C1, U2, C2> Preprocessor_t;

    // Default values and initialization:
    // Value_Chooser chooses its second argument as value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // worry about it, just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    StopVisiting stopvisiting;
    OOB_Visitor oob;
    typedef VisitorNode<OnlineLearnVisitor,
                        typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        inter(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    VisitorNode<OOB_Visitor, IntermedVis>
        visitor(oob, inter);
    #undef RF_CHOOSER
    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();


    // Make an stl compatible random functor.
    RandFunctor_t randint(random);


    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    // Initialize the trees.
    trees_.resize(options_.tree_count_, DecisionTree_t(ext_param_));

    /**\todo replace this crappy class. It uses function pointers
     * and is making the code slower according to me.
     */
    Sampler<RandFunctor_t> sampler(ext_param().actual_msample_,
                                   ext_param().row_count_,
                                   detail::make_sampler_opt(options_,
                                                            preprocessor.strata()),
                                   randint);

    visitor.visit_at_beginning(*this, preprocessor);
    // THE MAIN EFFING RF LOOP - YEAY DUDE!

    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        // Initialize the first region/node/stack entry.
        sampler
            .sample();
        StackEntry_t
            first_stack_entry( sampler.used_indices().begin(),
                               sampler.used_indices().end(),
                               ext_param_.class_count_);
        first_stack_entry
            .set_oob_range( sampler.unused_indices().begin(),
                            sampler.unused_indices().end());
        trees_[ii]
            .learn( preprocessor.features(),
                    preprocessor.response(),
                    first_stack_entry,
                    split,
                    stop,
                    visitor,
                    randint);
        visitor
            .visit_after_tree( *this,
                               preprocessor,
                               sampler,
                               first_stack_entry,
                               ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();

    return visitor.return_val();
}




template <class LabelType, class Tag>
template <class U, class C, class Stop>
LabelType RandomForest<LabelType, Tag>
    ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
{
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel():"
        " Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    typedef MultiArrayShape<2>::type Shp;
    garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
    LabelType d;
    predictProbabilities(features, garbage_prediction_, stop);
    ext_param_.to_classlabel(argMax(garbage_prediction_), d);
    return d;
}


// Same thing as above, with priors for each label !!!
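
// A sketch of prior-weighted prediction using the overload implemented
// below (hypothetical numbers; assumes a trained two-class forest whose
// LabelType is double):
//
//   ArrayVector<double> priors(2);
//   priors[0] = 0.25;  // down-weight class 0
//   priors[1] = 0.75;  // up-weight class 1
//   double label = rf.predictLabel(rowVector(test_features, 0), priors);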
template <class LabelType, class PreprocessorTag>
template <class U, class C>
LabelType RandomForest<LabelType, PreprocessorTag>
    ::predictLabel( MultiArrayView<2, U, C> const & features,
                    ArrayVectorView<double> priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double> prob(1, ext_param_.class_count_);
    predictProbabilities(features, prob);
    std::transform( prob.begin(), prob.end(),
                    priors.begin(), prob.begin(),
                    Arg1()*Arg2());
    LabelType d;
    ext_param_.to_classlabel(argMax(prob), d);
    return d;
}

template<class LabelType, class PreprocessorTag>
template <class T1, class T2, class C>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                           MultiArrayView<2, T2, C> & prob)
{
    // Features are n x p.
    // prob is n x NumOfLabel: the probability for each feature in each class.

    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");
    // The number of features must be bigger than the number of features
    // used in Random Forest training - but why bigger?
    vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");
    prob.init(0.0);
    // Store the total weights.
    std::vector<T1> totalWeights(predictionSet.indices[0].size(), 0.0);
    // Go through all trees.
    int set_id = -1;
    for(int k = 0; k < options_.tree_count_; ++k)
    {
        set_id = (set_id + 1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;
        //typedef std::set<std::pair<int, SampleRange<T1> > >::iterator set_it;
        // Build a stack with all the ranges we have.
        std::vector<std::pair<int, set_it> > stack;
        stack.clear();
        set_it i;
        for(i = predictionSet.ranges[set_id].begin();
            i != predictionSet.ranges[set_id].end(); ++i)
            stack.push_back(std::pair<int, set_it>(2, i));
        // Get the weights predicted by a single tree.
        int num_decisions = 0;
        while(!stack.empty())
        {
            set_it range = stack.back().second;
            int index = stack.back().first;
            stack.pop_back();
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                ArrayVector<double>::iterator weights =
                    Node<e_ConstProbNode>(trees_[k].topology_,
                                          trees_[k].parameters_,
                                          index).prob_begin();
                for(int i = range->start; i != range->end; ++i)
                {
                    // Update the vote count.
                    for(int l = 0; l < ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
                        // Add every weight to totalWeight.
                        totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
                    }
                }
            }
            else
            {
                if(trees_[k].topology_[index] != i_ThresholdNode)
                {
                    throw std::runtime_error("predicting with online prediction sets "
                                             "is only supported for RFs with threshold nodes");
                }
                Node<i_ThresholdNode> node(trees_[k].topology_,
                                           trees_[k].parameters_, index);
                if(range->min_boundaries[node.column()] >= node.threshold())
                {
                    // Everything goes to the right child.
                    stack.push_back(std::pair<int, set_it>(node.child(1), range));
                    continue;
                }
                if(range->max_boundaries[node.column()] < node.threshold())
                {
                    // Everything goes to the left child.
                    stack.push_back(std::pair<int, set_it>(node.child(0), range));
                    continue;
                }
                // We have to split at this node.
                SampleRange<T1> new_range = *range;
                new_range.min_boundaries[node.column()] = FLT_MAX;
                range->max_boundaries[node.column()] = -FLT_MAX;
                new_range.start = new_range.end = range->end;
                int i = range->start;
                while(i != range->end)
                {
                    // Decide for range->indices[i].
                    if(predictionSet.features(predictionSet.indices[set_id][i],
                                              node.column()) >= node.threshold())
                    {
                        new_range.min_boundaries[node.column()] =
                            std::min(new_range.min_boundaries[node.column()],
                                     predictionSet.features(predictionSet.indices[set_id][i],
                                                            node.column()));
                        --range->end;
                        --new_range.start;
                        std::swap(predictionSet.indices[set_id][i],
                                  predictionSet.indices[set_id][range->end]);
                    }
                    else
                    {
                        range->max_boundaries[node.column()] =
                            std::max(range->max_boundaries[node.column()],
                                     predictionSet.features(predictionSet.indices[set_id][i],
                                                            node.column()));
                        ++i;
                    }
                }
                // The old range ...
                if(range->start == range->end)
                {
                    predictionSet.ranges[set_id].erase(range);
                }
                else
                {
                    stack.push_back(std::pair<int, set_it>(node.child(0), range));
                }
                // ... and the new one.
                if(new_range.start != new_range.end)
                {
                    std::pair<set_it, bool> new_it =
                        predictionSet.ranges[set_id].insert(new_range);
                    stack.push_back(std::pair<int, set_it>(node.child(1), new_it.first));
                }
            }
        }
        predictionSet.cumulativePredTime[k] = num_decisions;
    }
    for(unsigned int i = 0; i < totalWeights.size(); ++i)
    {
        double test = 0.0;
        // Normalise the votes in each row by the total vote count (totalWeights).
        for(int l = 0; l < ext_param_.class_count_; ++l)
        {
            test += prob(i, l);
            prob(i, l) /= totalWeights[i];
        }
        assert(test == totalWeights[i]);
        assert(totalWeights[i] > 0.0);
    }
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1> const & features,
                           MultiArrayView<2, T, C2> & prob,
                           Stop_t & stop_) const
{
    // Features are n x p.
    // prob is n x NumOfLabel: the probability for each feature in each class.

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");

    // The number of features must be bigger than the number of features
    // used in Random Forest training - but why bigger?
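    // (Presumably because extra columns are harmless: the split nodes only
    // ever index columns recorded during training, so anything beyond
    // column_count_ is simply ignored.)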
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    #undef RF_CHOOSER
    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());
    if(tree_indices_.size() != 0)
    {
        std::random_shuffle(tree_indices_.begin(),
                            tree_indices_.end());
    }
    // Classify each row.
    for(int row = 0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        // totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        // Let each tree classify...
        for(int k = 0; k < options_.tree_count_; ++k)
        {
            // Get the weights predicted by a single tree.
            weights = trees_[tree_indices_[k]].predict(rowVector(features, row));

            // Update the vote count.
            int weighted = options_.predict_weighted_;
            for(int l = 0; l < ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights - 1))
                               + (1 - weighted));
                prob(row, l) += (T)cur_w;
                // Add every weight to totalWeight.
                totalWeight += cur_w;
            }
            if(stop.after_prediction(weights,
                                     k,
                                     rowVector(prob, row),
                                     totalWeight))
            {
                break;
            }
        }

        // Normalise the votes in each row by the total vote count (totalWeight).
        for(int l = 0; l < ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }

}

//@}

} // namespace vigra

#endif // VIGRA_RANDOM_FOREST_HXX