
vigra/random_forest.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <list>
#include <numeric>
#include <cfloat>   // for FLT_MAX, used in predictProbabilities() below
#include "mathutil.hxx"
#include "array_vector.hxx"
#include "sized_int.hxx"
#include "matrix.hxx"
#include "random.hxx"
#include "functorexpression.hxx"
#include "random_forest/rf_common.hxx"
#include "random_forest/rf_nodeproxy.hxx"
#include "random_forest/rf_split.hxx"
#include "random_forest/rf_decisionTree.hxx"
#include "random_forest/rf_visitors.hxx"
#include "random_forest/rf_region.hxx"
#include "sampling.hxx"
#include "random_forest/rf_preprocessing.hxx"
#include "random_forest/rf_online_prediction_set.hxx"
#include "random_forest/rf_earlystopping.hxx"
#include "random_forest/rf_ridge_split.hxx"

namespace vigra
{

/** \addtogroup MachineLearning Machine Learning

    This module provides classification algorithms that map
    features to labels or label probabilities.
    Look at the RandomForest class first for an overview of most of the
    functionality provided as well as use cases.
**/
//@{

namespace detail
{


/* \brief sampling option factory function
 */
inline SamplerOptions make_sampler_opt(RandomForestOptions & RF_opt)
{
    SamplerOptions return_opt;
    return_opt.withReplacement(RF_opt.sample_with_replacement_);
    return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
    return return_opt;
}
} //namespace detail

/** Random Forest class
 *
 * \tparam PreprocessorTag Class used to preprocess the input while
 *          learning and predicting (default: ClassificationTag).
 *          Currently available: ClassificationTag and RegressionTag.
 *          It is recommended to use Splitfunctor::Preprocessor_t when
 *          using custom split functors, as they may need the data to be
 *          in a different format.
 *          \sa Preprocessor
 *
 *  Simple usage for classification (regression is not yet supported):
 *  see RandomForest::learn() as well as RandomForestOptions() for
 *  additional options.
 *
 *  \code
 *  typedef xxx feature_t; // replace xxx with whichever type
 *  typedef yyy label_t;   // likewise
 *  MultiArrayView<2, feature_t> f = get_some_features();
 *  MultiArrayView<2, label_t>   l = get_some_labels();
 *  RandomForest<> rf;
 *  rf.learn(f, l);
 *
 *  MultiArrayView<2, feature_t> pf = get_some_unknown_features();
 *  MultiArrayView<2, label_t> prediction = allocate_space_for_response();
 *  MultiArrayView<2, double> prob  = allocate_space_for_probability();
 *
 *  rf.predictLabels(pf, prediction);
 *  rf.predictProbabilities(pf, prob);
 *  \endcode
 *
 *  Additional information such as the OOB error and variable importance
 *  measures is accessed via visitors defined in rf::visitors.
 *  Have a look at rf::split for other splitting methods.
 *
*/
template <class LabelType = double, class PreprocessorTag = ClassificationTag>
class RandomForest
{

  public:
    //public typedefs
    typedef RandomForestOptions             Options_t;
    typedef detail::DecisionTree            DecisionTree_t;
    typedef ProblemSpec<LabelType>          ProblemSpec_t;
    typedef GiniSplit                       Default_Split_t;
    typedef EarlyStoppStd                   Default_Stop_t;
    typedef rf::visitors::StopVisiting      Default_Visitor_t;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                    StackEntry_t;
    typedef LabelType                       LabelT;
  protected:

    /** optimisation for predictLabels
     * */
    mutable MultiArray<2, double> garbage_prediction_;

  public:

    //problem independent data.
    Options_t                                   options_;
    //problem dependent data members - only set if
    //a copy constructor, some sort of import
    //function or the learn function is called
    ArrayVector<DecisionTree_t>                 trees_;
    ProblemSpec_t                               ext_param_;
    /*mutable ArrayVector<int>                    tree_indices_;*/
    rf::visitors::OnlineLearnVisitor            online_visitor_;


    void reset()
    {
        ext_param_.clear();
        trees_.clear();
    }

  public:

    /** \name Constructors
     * Note: no copy constructor is specified, as no pointers are
     * manipulated in this class.
     */
    /*\{*/
    /**\brief default constructor
     *
     * \param options   general options to the Random Forest. Must be of type
     *                  Options_t
     * \param ext_param problem specific values that can be supplied
     *                  additionally (class weights, labels etc.)
     * \sa  RandomForestOptions, ProblemSpec
     *
     */
    RandomForest(Options_t const & options = Options_t(),
                 ProblemSpec_t const & ext_param = ProblemSpec_t())
    :
        options_(options),
        ext_param_(ext_param)/*,
        tree_indices_(options.tree_count_,0)*/
    {
        /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
            tree_indices_[ii] = ii;*/
    }

    /**\brief Create RF from external source
     * \param treeCount Number of trees to add.
     * \param topology_begin
     *                  Iterator to a container where the topology_ data
     *                  of the trees is stored.
     *                  The iterator should support at least treeCount forward
     *                  iterations (i.e. topology_end - topology_begin >= treeCount).
     * \param parameter_begin
     *                  Iterator to a container where the parameters_ data
     *                  of the trees is stored. The iterator should support at
     *                  least treeCount forward iterations.
     * \param problem_spec
     *                  Extrinsic parameters that specify the problem, e.g.
     *                  ClassCount, featureCount etc.
     * \param options   (optional) specify options used to train the original
     *                  Random forest. This parameter is not used anywhere
     *                  during prediction and thus is optional.
     *
     */
     /* TODO: This constructor may be replaced by a constructor using
     * NodeProxy iterators to encapsulate the underlying data type.
     */
    template<class TopologyIterator, class ParameterIterator>
    RandomForest(int                       treeCount,
                  TopologyIterator         topology_begin,
                  ParameterIterator        parameter_begin,
                  ProblemSpec_t const & problem_spec,
                  Options_t const &     options = Options_t())
    :
        options_(options),  // members are initialized in declaration order
        trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec)
    {
        for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
        {
            trees_[k].topology_ = *topology_begin;
            trees_[k].parameters_ = *parameter_begin;
        }
    }
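
    /* Hedged usage sketch (not part of the original documentation):
     * rebuilding a forest from externally stored tree data. The loader
     * functions below are hypothetical placeholders; the only real
     * requirements are the ones stated above (treeCount forward
     * iterations, value types assignable to topology_ / parameters_,
     * i.e. ArrayVector<Int32> resp. ArrayVector<double>).
     *
     * \code
     * std::vector<ArrayVector<Int32> >  topologies = load_topologies(); // hypothetical
     * std::vector<ArrayVector<double> > parameters = load_parameters(); // hypothetical
     * ProblemSpec<double> spec = load_problem_spec();                   // hypothetical
     * RandomForest<double> rf((int)topologies.size(),
     *                         topologies.begin(), parameters.begin(),
     *                         spec);
     * \endcode
     */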

    /*\}*/


    /** \name Data Access
     * data access interface - direct use of the member variables is deprecated
     */

    /*\{*/


    /**\brief return external parameters for viewing
     * \return ProblemSpec_t
     */
    ProblemSpec_t const & ext_param() const
    {
        vigra_precondition(ext_param_.used() == true,
           "RandomForest::ext_param(): "
           "Random forest has not been trained yet.");
        return ext_param_;
    }

    /**\brief set external parameters
     *
     *  \param in external parameters to be set
     *
     * Set external parameters explicitly.
     * If the random forest has not been trained, the preprocessor will
     * either ignore filling values set this way or will throw an exception
     * if values specified manually do not match the value calculated
     * during the preparation step.
     */
    void set_ext_param(ProblemSpec_t const & in)
    {
        vigra_precondition(ext_param_.used() == false,
            "RandomForest::set_ext_param():"
            "Random forest has been trained! Call reset()"
            "before specifying new extrinsic parameters.");
        ext_param_ = in;    // store the user-supplied parameters
    }

    /**\brief access random forest options
     *
     * \return random forest options
     */
    Options_t & set_options()
    {
        return options_;
    }


    /**\brief access const random forest options
     *
     * \return const Option_t
     */
    Options_t const & options() const
    {
        return options_;
    }

    /**\brief access const trees
     */
    DecisionTree_t const & tree(int index) const
    {
        return trees_[index];
    }

    /**\brief access trees
     */
    DecisionTree_t & tree(int index)
    {
        return trees_[index];
    }

    /*\}*/

    /**\brief return number of features used while
     * training.
     */
    int feature_count() const
    {
      return ext_param_.column_count_;
    }


    /**\brief return number of features used while
     * training.
     *
     * deprecated. Use feature_count() instead.
     */
    int column_count() const
    {
      return ext_param_.column_count_;
    }

    /**\brief return number of classes used while
     * training.
     */
    int class_count() const
    {
      return ext_param_.class_count_;
    }

    /**\brief return number of trees
     */
    int tree_count() const
    {
      return options_.tree_count_;
    }


    template<class U,class C1,
        class U2, class C2,
        class Split_t,
        class Stop_t,
        class Visitor_t,
        class Random_t>
    void onlineLearn(   MultiArrayView<2,U,C1> const & features,
                        MultiArrayView<2,U2,C2> const & response,
                        int new_start_index,
                        Visitor_t visitor_,
                        Split_t split_,
                        Stop_t stop_,
                        Random_t & random,
                        bool adjust_thresholds=false);

    template <class U, class C1, class U2,class C2>
    void onlineLearn(   MultiArrayView<2, U, C1> const  & features,
                        MultiArrayView<2, U2,C2> const  & labels,
                        int new_start_index,
                        bool adjust_thresholds=false)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        onlineLearn(features,
                    labels,
                    new_start_index,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd,
                    adjust_thresholds);
    }
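
    /* Hedged usage sketch (an assumption, not from the original docs):
     * online learning extends already grown trees with new samples, so
     * the forest must have been trained with online preparation enabled.
     * The option setter name prepare_online_learning() is assumed from
     * rf_common.hxx; new rows are appended to the feature/label matrices
     * and new_start_index marks the first new row.
     *
     * \code
     * RandomForest<int> rf(RandomForestOptions().prepare_online_learning(true));
     * rf.learn(features, labels);                       // initial batch
     * // ... append new rows to features/labels, keeping the old rows ...
     * rf.onlineLearn(features, labels, first_new_row);  // first_new_row: index of first appended row
     * \endcode
     */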

    template<class U,class C1,
        class U2, class C2,
        class Split_t,
        class Stop_t,
        class Visitor_t,
        class Random_t>
    void reLearnTree(MultiArrayView<2,U,C1> const & features,
                     MultiArrayView<2,U2,C2> const & response,
                     int treeId,
                     Visitor_t visitor_,
                     Split_t split_,
                     Stop_t stop_,
                     Random_t & random);

    template<class U, class C1, class U2, class C2>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int treeId)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        reLearnTree(features,
                    labels,
                    treeId,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd);
    }
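
    /* Hedged usage sketch: reLearnTree() discards a single tree and
     * re-grows it from a fresh bootstrap sample of the given data. As
     * the precondition in the implementation below shows, this also
     * requires online learning to have been enabled during training.
     *
     * \code
     * rf.reLearnTree(features, labels, 0);   // re-grow tree number 0
     * \endcode
     */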


    /**\name Learning
     * The following functions differ in the degree of customization
     * allowed.
     */
    /*\{*/
    /**\brief learn on data with custom config and random number generator
     *
     * \param features  an N x M matrix containing N samples with M
     *                  features
     * \param response  an N x D matrix containing the corresponding
     *                  response. Current split functors assume D to
     *                  be 1 and ignore any additional columns.
     *                  This is not enforced to allow future support
     *                  for uncertain labels, label independent strata etc.
     *                  The Preprocessor specified during construction
     *                  should be able to handle the features and labels.
     *                  See also: SplitFunctor, Preprocessing
     *
     * \param visitor   visitor which is to be applied after each split,
     *                  tree and at the end. Use rf_default() for the
     *                  default value (no visitors).
     *                  See also: rf::visitors
     * \param split     split functor used to calculate each split;
     *                  use rf_default() for the default value (GiniSplit).
     *                  See also: rf::split
     * \param stop
     *                  predicate used to decide when to stop splitting a node;
     *                  use rf_default() for the default value (EarlyStoppStd).
     * \param random    RandomNumberGenerator to be used. Use
     *                  rf_default() for the default value (RandomMT19937).
     *
     *
     */
    template <class U, class C1,
             class U2,class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void learn( MultiArrayView<2, U, C1> const  &   features,
                MultiArrayView<2, U2,C2> const  &   response,
                Visitor_t                           visitor,
                Split_t                             split,
                Stop_t                              stop,
                Random_t                 const  &   random);

    template <class U, class C1,
             class U2,class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const  &   features,
                MultiArrayView<2, U2,C2> const  &   response,
                Visitor_t                           visitor,
                Split_t                             split,
                Stop_t                              stop)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        learn(  features,
                response,
                visitor,
                split,
                stop,
                rnd);
    }

    template <class U, class C1, class U2,class C2, class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const  & features,
                MultiArrayView<2, U2,C2> const  & labels,
                Visitor_t                         visitor)
    {
        learn(  features,
                labels,
                visitor,
                rf_default(),
                rf_default());
    }

    template <class U, class C1, class U2,class C2,
              class Visitor_t, class Split_t>
    void learn(   MultiArrayView<2, U, C1> const  & features,
                  MultiArrayView<2, U2,C2> const  & labels,
                  Visitor_t                         visitor,
                  Split_t                           split)
    {
        learn(  features,
                labels,
                visitor,
                split,
                rf_default());
    }

    /**\brief learn on data with default configuration
     *
     * \param features  an N x M matrix containing N samples with M
     *                  features
     * \param labels    an N x D matrix containing the corresponding
     *                  N labels. Current split functors assume D to
     *                  be 1 and ignore any additional columns.
     *                  This is not enforced to allow future support
     *                  for uncertain labels.
     *
     * Learning is done with:
     *
     * - a randomly seeded random number generator
     * - the default Gini split functor as described by Breiman
     * - the standard early stopping criterion (EarlyStoppStd)
     *
     * \sa rf::split, EarlyStoppStd
     */
    template <class U, class C1, class U2,class C2>
    void learn(   MultiArrayView<2, U, C1> const  & features,
                  MultiArrayView<2, U2,C2> const  & labels)
    {
        learn(  features,
                labels,
                rf_default(),
                rf_default(),
                rf_default());
    }
    /*\}*/
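
    /* Hedged example (assumption: an out-of-bag error visitor named
     * rf::visitors::OOB_Error with a result member oob_breiman, and the
     * rf::visitors::create_visitor() factory, exist in rf_visitors.hxx):
     * statistics such as the OOB error are collected by passing visitors
     * to learn().
     *
     * \code
     * rf::visitors::OOB_Error oob;
     * rf.learn(features, labels, rf::visitors::create_visitor(oob));
     * double oob_error = oob.oob_breiman;   // assumed member name
     * \endcode
     */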


    /**\name Prediction
     */
    /*\{*/
    /** \brief predict a label given a feature.
     *
     * \param features: a 1 by featureCount matrix containing the
     *        data point to be predicted (this only works in a
     *        classification setting)
     * \param stop: early stopping criterion
     * \return LabelType value representing the class. You can use the
     *         predictLabels() function together with the
     *         rf.ext_param().class_type_ attribute
     *         to get back the same type used during learning.
     */
    template <class U, class C, class Stop>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;

    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features)
    {
        return predictLabel(features, rf_default());
    }
    /** \brief predict a label with features and class priors
     *
     * \param features: same as above.
     * \param prior:   iterator to prior weighting of classes
     * \return same as above.
     */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                ArrayVectorView<double> prior) const;

    /** \brief predict multiple labels with given features
     *
     * \param features: an n by featureCount matrix containing the
     *        data points to be predicted (this only works in a
     *        classification setting)
     * \param labels: an n by 1 matrix passed by reference to store the
     *        output.
     */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
    }

    template <class U, class C1, class T, class C2, class Stop>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels,
                       Stop                     & stop) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
    }
    /** \brief predict the class probabilities for multiple labels
     *
     *  \param features same as above
     *  \param prob an n x class_count_ matrix passed by reference to
     *  store the class probabilities
     *  \param stop early stopping criterion
     *  \sa EarlyStopping
     */
    template <class U, class C1, class T, class C2, class Stop>
    void predictProbabilities(MultiArrayView<2, U, C1>const &   features,
                              MultiArrayView<2, T, C2> &        prob,
                              Stop                     &        stop) const;
    template <class T1,class T2, class C>
    void predictProbabilities(OnlinePredictionSet<T1> &  predictionSet,
                               MultiArrayView<2, T2, C> &       prob);

    /** \brief predict the class probabilities for multiple labels
     *
     *  \param features same as above
     *  \param prob an n x class_count_ matrix passed by reference to
     *  store the class probabilities
     */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1>const &   features,
                              MultiArrayView<2, T, C2> &        prob)  const
    {
        predictProbabilities(features, prob, rf_default());
    }


    /*\}*/
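
    /* Hedged batch-prediction sketch: allocate an n x class_count()
     * probability matrix and an n x 1 label matrix, then fill both.
     * The shapes follow the preconditions documented above.
     *
     * \code
     * typedef MultiArrayShape<2>::type Shp;
     * MultiArray<2, double> probs(Shp(n, rf.class_count()));
     * MultiArray<2, int>    labels(Shp(n, 1));
     * rf.predictProbabilities(test_features, probs);
     * rf.predictLabels(test_features, labels);
     * \endcode
     */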

};


template <class LabelType, class PreprocessorTag>
template<class U,class C1,
    class U2, class C2,
    class Split_t,
    class Stop_t,
    class Visitor_t,
    class Random_t>
void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
                                                             MultiArrayView<2,U2,C2> const & response,
                                                             int new_start_index,
                                                             Visitor_t visitor_,
                                                             Split_t split_,
                                                             Stop_t stop_,
                                                             Random_t & random,
                                                             bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds=adjust_thresholds;

    using namespace rf;
    //typedefs
    typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
    typedef          UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;
    // default values and initialization
    // Value_Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT. (Thanks to template magic - don't
    // care about it - just smile and wave.)

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef  rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type>
                                                        IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    ext_param_.class_count_=0;
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    // Create Poisson samples (each new sample participates in each tree
    // with Poisson(1) multiplicity, as in online bagging).
    PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));

    //TODO: visitors for online learning
    //visitor.visit_at_beginning(*this, preprocessor);

    // THE MAIN EFFING RF LOOP - YEAY DUDE!
    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        online_visitor_.tree_id=ii;
        poisson_sampler.sample();
        std::map<int,int> leaf_parents;
        //Get all the leaf nodes for that sample
        for(int s=0;s<poisson_sampler.numOfSamples();++s)
        {
            int sample=poisson_sampler[s];
            online_visitor_.current_label=preprocessor.response()(sample,0);
            online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
            int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);


            //Add to the index list for that leaf
            online_visitor_.add_to_index_list(ii,leaf,sample);
            //TODO: Class count?
            //Store the parent, unless the leaf already predicts the
            //sample's class with probability 1
            if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
            {
                leaf_parents[leaf]=online_visitor_.last_node_id;
            }
        }


        std::map<int,int>::iterator leaf_iterator;
        for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
        {
            int leaf=leaf_iterator->first;
            int parent=leaf_iterator->second;
            int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indices;
            indices.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indices.begin(),
                                     indices.end(),
                                     ext_param_.class_count_);


            if(parent!=-1)
            {
                if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
                {
                    stack_entry.leftParent=parent;
                }
                else
                {
                    vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
                    stack_entry.rightParent=parent;
                }
            }
            //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
            trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
            //Now the last node was moved onto the leaf's position
            online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
            //Now it should be classified correctly!
        }

        /*visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                poisson_sampler,
                                stack_entry,
                                ii);*/
    }

    //visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();
}

template<class LabelType, class PreprocessorTag>
template<class U,class C1,
    class U2, class C2,
    class Split_t,
    class Stop_t,
    class Visitor_t,
    class Random_t>
void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
                 MultiArrayView<2,U2,C2> const & response,
                 int treeId,
                 Visitor_t visitor_,
                 Split_t split_,
                 Stop_t stop_,
                 Random_t & random)
{
    using namespace rf;


    typedef          UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    ext_param_.class_count_=0;
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    // default values and initialization
    // Value_Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT. (Thanks to template magic - don't
    // care about it - just smile and wave.)

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef  rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,
        "reLearnTree: relearning trees only makes sense if online learning is enabled");
    online_visitor_.activate();

    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    /**\todo    replace this crappy class. It uses function pointers
     *          and slows down the code, according to me.
     *          Comment from Nathan: This is copied from Rahul, so me = Rahul.
     */
    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                        .sampleSize(ext_param().actual_msample_),
                                    random);
    //initialize first region/node/stack entry
    sampler
        .sample();

    StackEntry_t
        first_stack_entry(  sampler.sampledIndices().begin(),
                            sampler.sampledIndices().end(),
                            ext_param_.class_count_);
    first_stack_entry
        .set_oob_range(     sampler.oobIndices().begin(),
                            sampler.oobIndices().end());
    online_visitor_.reset_tree(treeId);
    online_visitor_.tree_id=treeId;
    trees_[treeId].reset();
    trees_[treeId]
        .learn( preprocessor.features(),
                preprocessor.response(),
                first_stack_entry,
                split,
                stop,
                visitor,
                randint);
    visitor
        .visit_after_tree(  *this,
                            preprocessor,
                            sampler,
                            first_stack_entry,
                            treeId);

    online_visitor_.deactivate();
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1,
         class U2,class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
                     learn( MultiArrayView<2, U, C1> const  &   features,
                            MultiArrayView<2, U2,C2> const  &   response,
                            Visitor_t                           visitor_,
                            Split_t                             split_,
                            Stop_t                              stop_,
                            Random_t                 const  &   random)
{
    using namespace rf;
    //this->reset();
    //typedefs
    typedef          UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    // default values and initialization
    // Value_Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT. (Thanks to template magic - don't
    // care about it - just smile and wave.)

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef  rf::visitors::detail::VisitorNode<
                rf::visitors::OnlineLearnVisitor,
                typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();


    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);


    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    //initialize trees.
    trees_.resize(options_.tree_count_  , DecisionTree_t(ext_param_));

    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                        .sampleSize(ext_param().actual_msample_),
                                    random);

    visitor.visit_at_beginning(*this, preprocessor);
    // THE MAIN EFFING RF LOOP - YEAY DUDE!

    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        //initialize first region/node/stack entry
        sampler
            .sample();
        StackEntry_t
            first_stack_entry(  sampler.sampledIndices().begin(),
                                sampler.sampledIndices().end(),
                                ext_param_.class_count_);
        first_stack_entry
            .set_oob_range(     sampler.oobIndices().begin(),
                                sampler.oobIndices().end());
        trees_[ii]
            .learn(             preprocessor.features(),
                                preprocessor.response(),
                                first_stack_entry,
                                split,
                                stop,
                                visitor,
                                randint);
        visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                sampler,
                                first_stack_entry,
                                ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    // Only for online learning?
    online_visitor_.deactivate();
}




template <class LabelType, class Tag>
template <class U, class C, class Stop>
LabelType RandomForest<LabelType, Tag>
    ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
{
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel():"
            " Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
            " Feature matrix must have a single row.");
    typedef MultiArrayShape<2>::type Shp;
    garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
    LabelType          d;
    predictProbabilities(features, garbage_prediction_, stop);
    ext_param_.to_classlabel(argMax(garbage_prediction_), d);
    return d;
}


//Same thing as above, but with priors for each label
template <class LabelType, class PreprocessorTag>
template <class U, class C>
LabelType RandomForest<LabelType, PreprocessorTag>
    ::predictLabel( MultiArrayView<2, U, C> const & features,
                    ArrayVectorView<double> priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double>  prob(1,ext_param_.class_count_);
    predictProbabilities(features, prob);
    std::transform( prob.begin(), prob.end(),
                    priors.begin(), prob.begin(),
                    Arg1()*Arg2());
    LabelType          d;
    ext_param_.to_classlabel(argMax(prob), d);
    return d;
}

template<class LabelType,class PreprocessorTag>
template <class T1,class T2, class C>
void RandomForest<LabelType,PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> &  predictionSet,
                          MultiArrayView<2, T2, C> &       prob)
{
    //Features are n x p
    //prob is n x NumOfLabels: the probability for each sample in each class

    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
                       "RandomForest::predictProbabilities():"
                       " Feature matrix and probability matrix size mismatch.");
    // The feature matrix must have at least as many columns as were used
    // during training; any additional columns are ignored.
    vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
      "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
      "RandomForest::predictProbabilities():"
      " Probability matrix must have as many columns as there are classes.");
    prob.init(0.0);
    //store total weights
    std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
    //Go through all trees
    int set_id=-1;
    for(int k=0; k<options_.tree_count_; ++k)
    {
        set_id=(set_id+1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;
        //Build a stack with all the ranges we have
        std::vector<std::pair<int,set_it> > stack;
        set_it i;
        for(i=predictionSet.ranges[set_id].begin();i!=predictionSet.ranges[set_id].end();++i)
            stack.push_back(std::pair<int,set_it>(2,i));
        //get weights predicted by single tree
        int num_decisions=0;
        while(!stack.empty())
        {
            set_it range=stack.back().second;
            int index=stack.back().first;
            stack.pop_back();
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
                                                                            trees_[k].parameters_,
                                                                            index).prob_begin();
                for(int i=range->start;i!=range->end;++i)
                {
                    //update votecount.
                    for(int l=0; l<ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
                        //every weight in totalWeights.
                        totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
                    }
                }
            }

            else
            {
                if(trees_[k].topology_[index]!=i_ThresholdNode)
                {
                    throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
                }
                Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
                if(range->min_boundaries[node.column()]>=node.threshold())
                {
                    //Everything goes to the right child
                    stack.push_back(std::pair<int,set_it>(node.child(1),range));
                    continue;
                }
                if(range->max_boundaries[node.column()]<node.threshold())
                {
                    //Everything goes to the left child
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                    continue;
                }
                //We have to split at this node
                SampleRange<T1> new_range=*range;
                new_range.min_boundaries[node.column()]=FLT_MAX;
                range->max_boundaries[node.column()]=-FLT_MAX;
                new_range.start=new_range.end=range->end;
                int i=range->start;
                while(i!=range->end)
                {
                    //Decide for range->indices[i]
                    if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
                    {
                        new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
                                                                    predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        --range->end;
                        --new_range.start;
                        std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);

                    }
                    else
                    {
                        range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
                                                                 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        ++i;
                    }
                }
                //The old range ...
                if(range->start==range->end)
                {
                    predictionSet.ranges[set_id].erase(range);
                }
                else
                {
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                }
                //... and the new one
                if(new_range.start!=new_range.end)
                {
                    std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
                    stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
                }
            }
        }
        predictionSet.cumulativePredTime[k]=num_decisions;
    }
    for(unsigned int i=0;i<totalWeights.size();++i)
    {
        double test=0.0;
        //Normalise votes in each row by the total vote count (totalWeights)
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            test+=prob(i,l);
            prob(i, l) /= totalWeights[i];
        }
        assert(test==totalWeights[i]);
        assert(totalWeights[i]>0.0);
    }
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1>const &  features,
                           MultiArrayView<2, T, C2> &       prob,
                           Stop_t                   &       stop_) const
{
    //Features are n x p
    //prob is n x NumOfLabels: the probability for each sample in each class

    vigra_precondition(rowCount(features) == rowCount(prob),
      "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");

    // The feature matrix must have at least as many columns as were used
    // during training; any additional columns are ignored.
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
      "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
      "RandomForest::predictProbabilities():"
      " Probability matrix must have as many columns as there are classes.");

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    #undef RF_CHOOSER
    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
       std::random_shuffle(tree_indices_.begin(),
                           tree_indices_.end());
    }
    */
    //Classify each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        //totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        //Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            //get weights predicted by single tree
            weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));

            //update votecount. If weighted prediction is enabled, each
            //tree's votes are scaled by the leaf's total sample weight,
            //which is stored in the parameter entry just before the
            //class probabilities, i.e. at *(weights-1).
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += (T)cur_w;
                //every weight in totalWeight.
                totalWeight += cur_w;
            }
            if(stop.after_prediction(weights,
                                     k,
                                     rowVector(prob, row),
                                     totalWeight))
            {
                break;
            }
        }

        //Normalise votes in each row by the total vote count (totalWeight)
        for(int l=0; l< ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }

}

//@}

} // namespace vigra

#include "random_forest/rf_algorithm.hxx"
#endif // VIGRA_RANDOM_FOREST_HXX
