
random_forest.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
39 
40 #include <iostream>
41 #include <algorithm>
42 #include <map>
43 #include <set>
44 #include <list>
45 #include <numeric>
#include <cfloat>    // FLT_MAX (used in predictProbabilities(OnlinePredictionSet...))
#include <cassert>   // assert
#include <stdexcept> // std::runtime_error
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
49 #include "matrix.hxx"
50 #include "random.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
63 namespace vigra
64 {
65 
66 /** \addtogroup MachineLearning Machine Learning
67 
68  This module provides classification algorithms that map
69  features to labels or label probabilities.
 70  Look at the RandomForest class first for an overview of most of the
71  functionality provided as well as use cases.
72 **/
73 //@{
74 
75 namespace detail
76 {
77 
78 
79 
80 /* \brief sampling option factory function
81  */
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
83 {
84  SamplerOptions return_opt;
85  return_opt.withReplacement(RF_opt.sample_with_replacement_);
86  return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
87  return return_opt;
88 }
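// For example, when RF_opt has sample_with_replacement_ == true and
// stratification_method_ == RF_EQUAL, the result is equivalent to
// SamplerOptions().withReplacement(true).stratified(true) (the setters
// chain, as used with sampleSize() further below).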
89 }//namespace detail
90 
91 /** Random Forest class
92  *
 93  * \tparam PreprocessorTag Class used to preprocess the input while
 94  * learning and predicting. Currently available:
 95  * ClassificationTag and RegressionTag. It is recommended to use
 96  * Splitfunctor::Preprocessor_t when using custom split functors,
 97  * as they may need the data to be in a different format.
98  * \sa Preprocessor
99  *
 100  * Simple usage for classification (regression is not yet supported):
 101  * see RandomForest::learn() as well as RandomForestOptions for additional
 102  * options.
103  *
104  * \code
105  * using namespace vigra;
106  * using namespace rf;
 107  * typedef xxx feature_t; // replace xxx with whichever type
 108  * typedef yyy label_t; // likewise
109  *
110  * // allocate the training data
111  * MultiArrayView<2, feature_t> f = get_training_features();
112  * MultiArrayView<2, label_t> l = get_training_labels();
113  *
114  * RandomForest<> rf;
115  *
116  * // construct visitor to calculate out-of-bag error
117  * visitors::OOB_Error oob_v;
118  *
119  * // perform training
120  * rf.learn(f, l, visitors::create_visitor(oob_v));
121  *
122  * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
123  *
124  * // get features for new data to be used for prediction
125  * MultiArrayView<2, feature_t> pf = get_features();
126  *
127  * // allocate space for the response (pf.shape(0) is the number of samples)
 128  * MultiArray<2, label_t> prediction(Shape2(pf.shape(0), 1));
 129  * MultiArray<2, double> prob(Shape2(pf.shape(0), rf.class_count()));
130  *
131  * // perform prediction on new data
 132  * rf.predictLabels(pf, prediction);
 133  * rf.predictProbabilities(pf, prob);
134  *
135  * \endcode
136  *
 137  * Additional information, such as variable importance measures, is accessed
138  * via Visitors defined in rf::visitors.
139  * Have a look at rf::split for other splitting methods.
140  *
141 */
142 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
 143 class RandomForest
144 {
145 
146  public:
147  //public typedefs
 148  typedef RandomForestOptions Options_t;
149  typedef detail::DecisionTree DecisionTree_t;
 150  typedef ProblemSpec<LabelType> ProblemSpec_t;
151  typedef GiniSplit Default_Split_t;
 152  typedef EarlyStoppStd Default_Stop_t;
 153  typedef rf::visitors::StopVisiting Default_Visitor_t;
 154  typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
155  StackEntry_t;
156  typedef LabelType LabelT;
157  protected:
158 
159  /** optimisation for predictLabels
160  * */
 161  mutable MultiArray<2, double> garbage_prediction_;
162 
163  public:
164 
165  //problem independent data.
166  Options_t options_;
167  //problem dependent data members - is only set if
168  //a copy constructor, some sort of import
169  //function or the learn function is called
 170  ArrayVector<DecisionTree_t> trees_;
171  ProblemSpec_t ext_param_;
172  /*mutable ArrayVector<int> tree_indices_;*/
173  rf::visitors::OnlineLearnVisitor online_visitor_;
174 
175 
176  void reset()
177  {
178  ext_param_.clear();
179  trees_.clear();
180  }
181 
182  public:
183 
184  /** \name Constructors
 185  * Note: no copy constructor is specified, as no pointers are
 186  * manipulated in this class.
187  */
188  /*\{*/
189  /**\brief default constructor
190  *
191  * \param options general options to the Random Forest. Must be of Type
192  * Options_t
193  * \param ext_param problem specific values that can be supplied
 194  * additionally (class weights, labels, etc.)
195  * \sa RandomForestOptions, ProblemSpec
196  *
197  */
 198  RandomForest(Options_t const & options = Options_t(),
 199               ProblemSpec_t const & ext_param = ProblemSpec_t())
200  :
201  options_(options),
202  ext_param_(ext_param)/*,
203  tree_indices_(options.tree_count_,0)*/
204  {
205  /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
206  tree_indices_[ii] = ii;*/
207  }
208 
209  /**\brief Create RF from external source
210  * \param treeCount Number of trees to add.
211  * \param topology_begin
212  * Iterator to a Container where the topology_ data
213  * of the trees are stored.
 214  * Iterator should support at least treeCount forward
 215  * iterations (i.e., topology_end - topology_begin >= treeCount).
216  * \param parameter_begin
217  * iterator to a Container where the parameters_ data
218  * of the trees are stored. Iterator should support at
219  * least treeCount forward iterations.
220  * \param problem_spec
221  * Extrinsic parameters that specify the problem e.g.
222  * ClassCount, featureCount etc.
223  * \param options (optional) specify options used to train the original
224  * Random forest. This parameter is not used anywhere
225  * during prediction and thus is optional.
226  *
227  */
228  /* TODO: This constructor may be replaced by a Constructor using
229  * NodeProxy iterators to encapsulate the underlying data type.
230  */
231  template<class TopologyIterator, class ParameterIterator>
232  RandomForest(int treeCount,
233  TopologyIterator topology_begin,
234  ParameterIterator parameter_begin,
235  ProblemSpec_t const & problem_spec,
236  Options_t const & options = Options_t())
237  :
 238  options_(options),
 239  trees_(treeCount, DecisionTree_t(problem_spec)),
 240  ext_param_(problem_spec)
241  {
 242  for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
243  {
244  trees_[k].topology_ = *topology_begin;
245  trees_[k].parameters_ = *parameter_begin;
246  }
247  }
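 /* A usage sketch (not from the original docs): rebuilding a forest from
  * per-tree data stored externally. load_topologies(), load_parameters()
  * and load_problem_spec() are hypothetical helpers; topology_ and
  * parameters_ are ArrayVector<Int32> and ArrayVector<double> per tree.
  *
  *   std::vector<ArrayVector<Int32> >  topo   = load_topologies();
  *   std::vector<ArrayVector<double> > params = load_parameters();
  *   ProblemSpec<> spec = load_problem_spec(); // class_count_, column_count_, ...
  *   RandomForest<> imported((int)topo.size(),
  *                           topo.begin(), params.begin(), spec);
  */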
248 
249  /*\}*/
250 
251 
252  /** \name Data Access
253  * data access interface - usage of member variables is deprecated
254  */
255 
256  /*\{*/
257 
258 
259  /**\brief return external parameters for viewing
260  * \return ProblemSpec_t
261  */
262  ProblemSpec_t const & ext_param() const
263  {
264  vigra_precondition(ext_param_.used() == true,
265  "RandomForest::ext_param(): "
266  "Random forest has not been trained yet.");
267  return ext_param_;
268  }
269 
270  /**\brief set external parameters
271  *
272  * \param in external parameters to be set
273  *
 274  * Set external parameters explicitly.
 275  * If the random forest has not been trained, the preprocessor will
 276  * either ignore filling values set this way or will throw an exception
 277  * if values specified manually do not match the values calculated
 278  * during the preparation step.
279  */
280  void set_ext_param(ProblemSpec_t const & in)
281  {
282  vigra_precondition(ext_param_.used() == false,
283  "RandomForest::set_ext_param():"
284  "Random forest has been trained! Call reset()"
285  "before specifying new extrinsic parameters.");
      ext_param_ = in;  // store the supplied parameters
 286  }
287 
288  /**\brief access random forest options
289  *
290  * \return random forest options
291  */
 292  Options_t & options()
293  {
 294  return options_;
295  }
296 
297 
298  /**\brief access const random forest options
299  *
300  * \return const Option_t
301  */
302  Options_t const & options() const
303  {
304  return options_;
305  }
306 
307  /**\brief access const trees
308  */
309  DecisionTree_t const & tree(int index) const
310  {
311  return trees_[index];
312  }
313 
314  /**\brief access trees
315  */
316  DecisionTree_t & tree(int index)
317  {
318  return trees_[index];
319  }
320 
321  /*\}*/
322 
323  /**\brief return number of features used while
324  * training.
325  */
326  int feature_count() const
327  {
328  return ext_param_.column_count_;
329  }
330 
331 
332  /**\brief return number of features used while
333  * training.
334  *
335  * deprecated. Use feature_count() instead.
336  */
337  int column_count() const
338  {
339  return ext_param_.column_count_;
340  }
341 
342  /**\brief return number of classes used while
343  * training.
344  */
345  int class_count() const
346  {
347  return ext_param_.class_count_;
348  }
349 
350  /**\brief return number of trees
351  */
352  int tree_count() const
353  {
354  return options_.tree_count_;
355  }
356 
357 
358 
359  template<class U,class C1,
360  class U2, class C2,
361  class Split_t,
362  class Stop_t,
363  class Visitor_t,
364  class Random_t>
365  void onlineLearn( MultiArrayView<2,U,C1> const & features,
366  MultiArrayView<2,U2,C2> const & response,
367  int new_start_index,
368  Visitor_t visitor_,
369  Split_t split_,
370  Stop_t stop_,
371  Random_t & random,
372  bool adjust_thresholds=false);
373 
374  template <class U, class C1, class U2,class C2>
375  void onlineLearn( MultiArrayView<2, U, C1> const & features,
376  MultiArrayView<2, U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false)
377  {
 378  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
379  onlineLearn(features,
380  labels,
381  new_start_index,
382  rf_default(),
383  rf_default(),
384  rf_default(),
385  rnd,
386  adjust_thresholds);
387  }
388 
389  template<class U,class C1,
390  class U2, class C2,
391  class Split_t,
392  class Stop_t,
393  class Visitor_t,
394  class Random_t>
395  void reLearnTree(MultiArrayView<2,U,C1> const & features,
396  MultiArrayView<2,U2,C2> const & response,
397  int treeId,
398  Visitor_t visitor_,
399  Split_t split_,
400  Stop_t stop_,
401  Random_t & random);
402 
403  template<class U, class C1, class U2, class C2>
404  void reLearnTree(MultiArrayView<2, U, C1> const & features,
405  MultiArrayView<2, U2, C2> const & labels,
406  int treeId)
407  {
408  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
409  reLearnTree(features,
410  labels,
411  treeId,
412  rf_default(),
413  rf_default(),
414  rf_default(),
415  rnd);
416  }
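 /* Sketch of the intended online-learning workflow, inferred from the
  * implementations below (the prepare_online_learning() setter name is an
  * assumption; the corresponding option member used below is
  * options_.prepare_online_learning_):
  *
  *   RandomForest<> rf(RandomForestOptions().prepare_online_learning(true));
  *   rf.learn(features, labels);              // initial batch of N rows
  *   // append new samples to features/labels, then pass the full arrays
  *   // together with the index of the first new row:
  *   rf.onlineLearn(all_features, all_labels, N);
  *   rf.reLearnTree(all_features, all_labels, 0); // optionally rebuild tree 0
  */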
417 
418 
419  /**\name Learning
 420  * The following functions differ in the degree of
 421  * customization they allow.
422  */
423  /*\{*/
424  /**\brief learn on data with custom config and random number generator
425  *
426  * \param features a N x M matrix containing N samples with M
427  * features
428  * \param response a N x D matrix containing the corresponding
429  * response. Current split functors assume D to
430  * be 1 and ignore any additional columns.
431  * This is not enforced to allow future support
432  * for uncertain labels, label independent strata etc.
433  * The Preprocessor specified during construction
 434  * should be able to handle the given
 435  * features and labels.
436  * see also: SplitFunctor, Preprocessing
437  *
438  * \param visitor visitor which is to be applied after each split,
439  * tree and at the end. Use rf_default for using
440  * default value. (No Visitors)
441  * see also: rf::visitors
442  * \param split split functor to be used to calculate each split
443  * use rf_default() for using default value. (GiniSplit)
444  * see also: rf::split
445  * \param stop
 446  * predicate used to decide when to stop splitting a node;
447  * use rf_default() for using default value. (EarlyStoppStd)
448  * \param random RandomNumberGenerator to be used. Use
 449  * rf_default() to use the default value. (RandomMT19937)
450  *
451  *
452  */
453  template <class U, class C1,
454  class U2,class C2,
455  class Split_t,
456  class Stop_t,
457  class Visitor_t,
458  class Random_t>
459  void learn( MultiArrayView<2, U, C1> const & features,
460  MultiArrayView<2, U2,C2> const & response,
461  Visitor_t visitor,
462  Split_t split,
463  Stop_t stop,
464  Random_t const & random);
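 /* A fully customized call might look as follows (a sketch: GiniSplit and
  * EarlyStoppStd are the defaults spelled out explicitly, and the
  * tree_count() option setter is assumed from RandomForestOptions):
  *
  *   RandomForest<> rf(RandomForestOptions().tree_count(100));
  *   GiniSplit split;
  *   EarlyStoppStd stop;
  *   RandomNumberGenerator<> rnd(RandomSeed);
  *   rf.learn(features, labels, rf_default(), split, stop, rnd);
  */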
465 
466  template <class U, class C1,
467  class U2,class C2,
468  class Split_t,
469  class Stop_t,
470  class Visitor_t>
471  void learn( MultiArrayView<2, U, C1> const & features,
472  MultiArrayView<2, U2,C2> const & response,
473  Visitor_t visitor,
474  Split_t split,
475  Stop_t stop)
476 
477  {
478  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
479  learn( features,
480  response,
481  visitor,
482  split,
483  stop,
484  rnd);
485  }
486 
487  template <class U, class C1, class U2,class C2, class Visitor_t>
488  void learn( MultiArrayView<2, U, C1> const & features,
489  MultiArrayView<2, U2,C2> const & labels,
490  Visitor_t visitor)
491  {
492  learn( features,
493  labels,
494  visitor,
495  rf_default(),
496  rf_default());
497  }
498 
499  template <class U, class C1, class U2,class C2,
500  class Visitor_t, class Split_t>
501  void learn( MultiArrayView<2, U, C1> const & features,
502  MultiArrayView<2, U2,C2> const & labels,
503  Visitor_t visitor,
504  Split_t split)
505  {
506  learn( features,
507  labels,
508  visitor,
509  split,
510  rf_default());
511  }
512 
513  /**\brief learn on data with default configuration
514  *
515  * \param features a N x M matrix containing N samples with M
516  * features
517  * \param labels a N x D matrix containing the corresponding
518  * N labels. Current split functors assume D to
519  * be 1 and ignore any additional columns.
 520  * This is not enforced to allow future support
521  * for uncertain labels.
522  *
523  * learning is done with:
524  *
525  * \sa rf::split, EarlyStoppStd
526  *
 527  * - a randomly seeded random number generator
 528  * - the default gini split functor as described by Breiman
 529  * - the standard early stopping criterion (EarlyStoppStd)
530  */
531  template <class U, class C1, class U2,class C2>
532  void learn( MultiArrayView<2, U, C1> const & features,
533  MultiArrayView<2, U2,C2> const & labels)
534  {
535  learn( features,
536  labels,
537  rf_default(),
538  rf_default(),
539  rf_default());
540  }
541  /*\}*/
542 
543 
544 
545  /**\name prediction
546  */
547  /*\{*/
548  /** \brief predict a label given a feature.
549  *
 550  * \param features: a 1 x featureCount matrix containing the
 551  * data point to be predicted (this only works in the
 552  * classification setting)
 553  * \param stop: early stopping criterion
 554  * \return double value representing the class. You can use the
 555  * predictLabels() function together with the
 556  * rf.ext_param().class_type_ attribute
 557  * to get back the same type used during learning.
558  */
559  template <class U, class C, class Stop>
560  LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
561 
562  template <class U, class C>
563  LabelType predictLabel(MultiArrayView<2, U, C>const & features)
564  {
565  return predictLabel(features, rf_default());
566  }
567  /** \brief predict a label with features and class priors
568  *
569  * \param features: same as above.
570  * \param prior: iterator to prior weighting of classes
 571  * \return same as above.
572  */
573  template <class U, class C>
574  LabelType predictLabel(MultiArrayView<2, U, C> const & features,
575  ArrayVectorView<double> prior) const;
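 /* Sketch: biasing a prediction with class priors; 'sample' is a
  * hypothetical 1 x feature_count() feature row. The implementation simply
  * multiplies the predicted class probabilities by the priors elementwise.
  *
  *   ArrayVector<double> prior(rf.class_count(), 1.0);
  *   prior[0] = 0.25;               // down-weight class 0
  *   double label = rf.predictLabel(sample, prior);
  */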
576 
577  /** \brief predict multiple labels with given features
578  *
 579  * \param features: a n x featureCount matrix containing the
 580  * data points to be predicted (this only works in the
 581  * classification setting)
582  * \param labels: a n by 1 matrix passed by reference to store
583  * output.
584  */
585  template <class U, class C1, class T, class C2>
 586  void predictLabels(MultiArrayView<2, U, C1>const & features,
587  MultiArrayView<2, T, C2> & labels) const
588  {
589  vigra_precondition(features.shape(0) == labels.shape(0),
590  "RandomForest::predictLabels(): Label array has wrong size.");
591  for(int k=0; k<features.shape(0); ++k)
592  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
593  }
594 
595  template <class U, class C1, class T, class C2, class Stop>
596  void predictLabels(MultiArrayView<2, U, C1>const & features,
597  MultiArrayView<2, T, C2> & labels,
598  Stop & stop) const
599  {
600  vigra_precondition(features.shape(0) == labels.shape(0),
601  "RandomForest::predictLabels(): Label array has wrong size.");
602  for(int k=0; k<features.shape(0); ++k)
603  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
604  }
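 /* predictProbabilities() calls stop.after_prediction() once per tree and
  * stops as soon as it returns true (see the prediction loop below). A
  * minimal hand-written criterion might look like this sketch, which stops
  * after a fixed number of trees:
  *
  *   struct StopAfterKTrees
  *   {
  *       int max_trees_;
  *       explicit StopAfterKTrees(int k) : max_trees_(k) {}
  *       template <class PS>
  *       void set_external_parameters(PS const &, int = 0) {}
  *       template <class Iter, class Row>
  *       bool after_prediction(Iter, int k, Row const &, double)
  *       { return k + 1 >= max_trees_; }
  *   };
  *
  *   StopAfterKTrees stop(10);
  *   rf.predictLabels(test_features, predicted_labels, stop);
  */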
605  /** \brief predict the class probabilities for multiple labels
606  *
607  * \param features same as above
 608  * \param prob a n x class_count_ matrix, passed by reference,
 609  * to store the class probabilities
610  * \param stop earlystopping criterion
611  * \sa EarlyStopping
612  */
613  template <class U, class C1, class T, class C2, class Stop>
614  void predictProbabilities(MultiArrayView<2, U, C1>const & features,
615  MultiArrayView<2, T, C2> & prob,
616  Stop & stop) const;
617  template <class T1,class T2, class C>
618  void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
619  MultiArrayView<2, T2, C> & prob);
620 
621  /** \brief predict the class probabilities for multiple labels
622  *
623  * \param features same as above
 624  * \param prob a n x class_count_ matrix, passed by reference,
 625  * to store the class probabilities
626  */
627  template <class U, class C1, class T, class C2>
 628  void predictProbabilities(MultiArrayView<2, U, C1>const & features,
629  MultiArrayView<2, T, C2> & prob) const
630  {
631  predictProbabilities(features, prob, rf_default());
632  }
633 
634  template <class U, class C1, class T, class C2>
635  void predictRaw(MultiArrayView<2, U, C1>const & features,
636  MultiArrayView<2, T, C2> & prob) const;
637 
638 
639  /*\}*/
640 
641 };
642 
643 
644 template <class LabelType, class PreprocessorTag>
645 template<class U,class C1,
646  class U2, class C2,
647  class Split_t,
648  class Stop_t,
649  class Visitor_t,
650  class Random_t>
651 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
652  MultiArrayView<2,U2,C2> const & response,
653  int new_start_index,
654  Visitor_t visitor_,
655  Split_t split_,
656  Stop_t stop_,
657  Random_t & random,
658  bool adjust_thresholds)
659 {
660  online_visitor_.activate();
661  online_visitor_.adjust_thresholds=adjust_thresholds;
662 
663  using namespace rf;
664  //typedefs
665  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
666  typedef UniformIntRandomFunctor<Random_t>
667  RandFunctor_t;
668  // default values and initialization
669  // Value Chooser chooses second argument as value if first argument
670  // is of type RF_DEFAULT. (thanks to template magic - don't care about
 671  // it - just smile and wave.)
672 
673  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
674  Default_Stop_t default_stop(options_);
675  typename RF_CHOOSER(Stop_t)::type stop
676  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
677  Default_Split_t default_split;
678  typename RF_CHOOSER(Split_t)::type split
679  = RF_CHOOSER(Split_t)::choose(split_, default_split);
680  rf::visitors::StopVisiting stopvisiting;
681  typedef rf::visitors::detail::VisitorNode
682  <rf::visitors::OnlineLearnVisitor,
683  typename RF_CHOOSER(Visitor_t)::type>
684  IntermedVis;
685  IntermedVis
686  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
687  #undef RF_CHOOSER
688 
689  // Preprocess the data to get something the split functor can work
690  // with. Also fill the ext_param structure by preprocessing
691  // option parameters that could only be completely evaluated
692  // when the training data is known.
693  ext_param_.class_count_=0;
694  Preprocessor_t preprocessor( features, response,
695  options_, ext_param_);
696 
697  // Make stl compatible random functor.
698  RandFunctor_t randint ( random);
699 
700  // Give the Split functor information about the data.
701  split.set_external_parameters(ext_param_);
702  stop.set_external_parameters(ext_param_);
703 
704 
705  //Create poisson samples
706  PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));
707 
708  //TODO: visitors for online learning
709  //visitor.visit_at_beginning(*this, preprocessor);
710 
711  // THE MAIN EFFING RF LOOP - YEAY DUDE!
712  for(int ii = 0; ii < (int)trees_.size(); ++ii)
713  {
714  online_visitor_.tree_id=ii;
715  poisson_sampler.sample();
716  std::map<int,int> leaf_parents;
717  leaf_parents.clear();
718  //Get all the leaf nodes for that sample
719  for(int s=0;s<poisson_sampler.numOfSamples();++s)
720  {
721  int sample=poisson_sampler[s];
722  online_visitor_.current_label=preprocessor.response()(sample,0);
723  online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
724  int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
725 
726 
727  //Add to the list for that leaf
728  online_visitor_.add_to_index_list(ii,leaf,sample);
729  //TODO: Class count?
730  //Store parent
731  if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
732  {
733  leaf_parents[leaf]=online_visitor_.last_node_id;
734  }
735  }
736 
737 
738  std::map<int,int>::iterator leaf_iterator;
739  for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
740  {
741  int leaf=leaf_iterator->first;
742  int parent=leaf_iterator->second;
743  int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
744  ArrayVector<Int32> indeces;
745  indeces.clear();
746  indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
747  StackEntry_t stack_entry(indeces.begin(),
748  indeces.end(),
749  ext_param_.class_count_);
750 
751 
752  if(parent!=-1)
753  {
754  if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
755  {
756  stack_entry.leftParent=parent;
757  }
758  else
759  {
760  vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
761  stack_entry.rightParent=parent;
762  }
763  }
764  //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
765  trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
766  //Now, the last one moved onto leaf
767  online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
768  //Now it should be classified correctly!
769  }
770 
771  /*visitor
772  .visit_after_tree( *this,
773  preprocessor,
774  poisson_sampler,
775  stack_entry,
776  ii);*/
777  }
778 
779  //visitor.visit_at_end(*this, preprocessor);
780  online_visitor_.deactivate();
781 }
782 
783 template<class LabelType, class PreprocessorTag>
784 template<class U,class C1,
785  class U2, class C2,
786  class Split_t,
787  class Stop_t,
788  class Visitor_t,
789  class Random_t>
 790 void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
791  MultiArrayView<2,U2,C2> const & response,
792  int treeId,
793  Visitor_t visitor_,
794  Split_t split_,
795  Stop_t stop_,
796  Random_t & random)
797 {
798  using namespace rf;
799 
800 
 801  typedef UniformIntRandomFunctor<Random_t>
802  RandFunctor_t;
803 
804  // See rf_preprocessing.hxx for more info on this
805  ext_param_.class_count_=0;
 806  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
807 
808  // default values and initialization
809  // Value Chooser chooses second argument as value if first argument
810  // is of type RF_DEFAULT. (thanks to template magic - don't care about
 811  // it - just smile and wave.)
812 
813  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
814  Default_Stop_t default_stop(options_);
815  typename RF_CHOOSER(Stop_t)::type stop
816  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
817  Default_Split_t default_split;
818  typename RF_CHOOSER(Split_t)::type split
819  = RF_CHOOSER(Split_t)::choose(split_, default_split);
820  rf::visitors::StopVisiting stopvisiting;
 821  typedef rf::visitors::detail::VisitorNode
 822  <rf::visitors::OnlineLearnVisitor,
823  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
824  IntermedVis
825  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
826  #undef RF_CHOOSER
 827  vigra_precondition(options_.prepare_online_learning_,"reLearnTree: relearning a tree only makes sense if online learning is enabled");
828  online_visitor_.activate();
829 
830  // Make stl compatible random functor.
831  RandFunctor_t randint ( random);
832 
833  // Preprocess the data to get something the split functor can work
834  // with. Also fill the ext_param structure by preprocessing
835  // option parameters that could only be completely evaluated
836  // when the training data is known.
837  Preprocessor_t preprocessor( features, response,
838  options_, ext_param_);
839 
840  // Give the Split functor information about the data.
841  split.set_external_parameters(ext_param_);
842  stop.set_external_parameters(ext_param_);
843 
 844  /**\todo replace this class; it uses function pointers
 845  * and is making the code slower, according to me.
846  * Comment from Nathan: This is copied from Rahul, so me=Rahul
847  */
848  Sampler<Random_t > sampler(preprocessor.strata().begin(),
849  preprocessor.strata().end(),
850  detail::make_sampler_opt(options_)
851  .sampleSize(ext_param().actual_msample_),
852  &random);
853  //initialize First region/node/stack entry
854  sampler
855  .sample();
856 
 857  StackEntry_t
858  first_stack_entry( sampler.sampledIndices().begin(),
859  sampler.sampledIndices().end(),
860  ext_param_.class_count_);
861  first_stack_entry
862  .set_oob_range( sampler.oobIndices().begin(),
863  sampler.oobIndices().end());
864  online_visitor_.reset_tree(treeId);
865  online_visitor_.tree_id=treeId;
866  trees_[treeId].reset();
867  trees_[treeId]
868  .learn( preprocessor.features(),
869  preprocessor.response(),
870  first_stack_entry,
871  split,
872  stop,
873  visitor,
874  randint);
875  visitor
876  .visit_after_tree( *this,
877  preprocessor,
878  sampler,
879  first_stack_entry,
880  treeId);
881 
882  online_visitor_.deactivate();
883 }
884 
885 template <class LabelType, class PreprocessorTag>
886 template <class U, class C1,
887  class U2,class C2,
888  class Split_t,
889  class Stop_t,
890  class Visitor_t,
891  class Random_t>
 892 void RandomForest<LabelType, PreprocessorTag>::
 893      learn( MultiArrayView<2, U, C1> const & features,
894  MultiArrayView<2, U2,C2> const & response,
895  Visitor_t visitor_,
896  Split_t split_,
897  Stop_t stop_,
898  Random_t const & random)
899 {
900  using namespace rf;
901  //this->reset();
902  //typedefs
 903  typedef UniformIntRandomFunctor<Random_t>
904  RandFunctor_t;
905 
906  // See rf_preprocessing.hxx for more info on this
 907  typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
908 
909  vigra_precondition(features.shape(0) == response.shape(0),
910  "RandomForest::learn(): shape mismatch between features and response.");
911 
912  // default values and initialization
913  // Value Chooser chooses second argument as value if first argument
914  // is of type RF_DEFAULT. (thanks to template magic - don't care about
 915  // it - just smile and wave.)
916 
917  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
918  Default_Stop_t default_stop(options_);
919  typename RF_CHOOSER(Stop_t)::type stop
920  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
921  Default_Split_t default_split;
922  typename RF_CHOOSER(Split_t)::type split
923  = RF_CHOOSER(Split_t)::choose(split_, default_split);
924  rf::visitors::StopVisiting stopvisiting;
 925  typedef rf::visitors::detail::VisitorNode
 926  <rf::visitors::OnlineLearnVisitor,
927  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
928  IntermedVis
929  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
930  #undef RF_CHOOSER
931  if(options_.prepare_online_learning_)
932  online_visitor_.activate();
933  else
934  online_visitor_.deactivate();
935 
936 
937  // Make stl compatible random functor.
938  RandFunctor_t randint ( random);
939 
940 
941  // Preprocess the data to get something the split functor can work
942  // with. Also fill the ext_param structure by preprocessing
943  // option parameters that could only be completely evaluated
944  // when the training data is known.
945  Preprocessor_t preprocessor( features, response,
946  options_, ext_param_);
947 
948  // Give the Split functor information about the data.
949  split.set_external_parameters(ext_param_);
950  stop.set_external_parameters(ext_param_);
951 
952 
953  //initialize trees.
954  trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
955 
956  Sampler<Random_t > sampler(preprocessor.strata().begin(),
957  preprocessor.strata().end(),
958  detail::make_sampler_opt(options_)
959  .sampleSize(ext_param().actual_msample_),
960  &random);
961 
962  visitor.visit_at_beginning(*this, preprocessor);
963  // THE MAIN EFFING RF LOOP - YEAY DUDE!
964 
965  for(int ii = 0; ii < (int)trees_.size(); ++ii)
966  {
967  //initialize First region/node/stack entry
968  sampler
969  .sample();
 970  StackEntry_t
971  first_stack_entry( sampler.sampledIndices().begin(),
972  sampler.sampledIndices().end(),
973  ext_param_.class_count_);
974  first_stack_entry
975  .set_oob_range( sampler.oobIndices().begin(),
976  sampler.oobIndices().end());
977  trees_[ii]
978  .learn( preprocessor.features(),
979  preprocessor.response(),
980  first_stack_entry,
981  split,
982  stop,
983  visitor,
984  randint);
985  visitor
986  .visit_after_tree( *this,
987  preprocessor,
988  sampler,
989  first_stack_entry,
990  ii);
991  }
992 
993  visitor.visit_at_end(*this, preprocessor);
994  // Only for online learning?
995  online_visitor_.deactivate();
996 }
997 
998 
999 
1000 
1001 template <class LabelType, class Tag>
1002 template <class U, class C, class Stop>
 1003 LabelType RandomForest<LabelType, Tag>
1004  ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
1005 {
 1006  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
 1007  "RandomForest::predictLabel():"
 1008  " Too few columns in feature matrix.");
 1009  vigra_precondition(rowCount(features) == 1,
 1010  "RandomForest::predictLabel():"
 1011  " Feature matrix must have a single row.");
1012  typedef MultiArrayShape<2>::type Shp;
1013  garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
1014  LabelType d;
1015  predictProbabilities(features, garbage_prediction_, stop);
1016  ext_param_.to_classlabel(argMax(garbage_prediction_), d);
1017  return d;
1018 }
1019 
1020 
1021 //Same thing as above with priors for each label !!!
1022 template <class LabelType, class PreprocessorTag>
1023 template <class U, class C>
 1024 LabelType RandomForest<LabelType, PreprocessorTag>
 1025  ::predictLabel(MultiArrayView<2, U, C> const & features,
1026  ArrayVectorView<double> priors) const
1027 {
1028  using namespace functor;
1029  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1030  "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1031  vigra_precondition(rowCount(features) == 1,
1032  "RandomForestn::predictLabel():"
1033  " Feature matrix must have a single row.");
1034  Matrix<double> prob(1,ext_param_.class_count_);
1035  predictProbabilities(features, prob);
1036  std::transform( prob.begin(), prob.end(),
1037  priors.begin(), prob.begin(),
1038  Arg1()*Arg2());
1039  LabelType d;
1040  ext_param_.to_classlabel(argMax(prob), d);
1041  return d;
1042 }
1043 
1044 template<class LabelType,class PreprocessorTag>
1045 template <class T1,class T2, class C>
 1046 void RandomForest<LabelType,PreprocessorTag>
1047  ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
1048  MultiArrayView<2, T2, C> & prob)
1049 {
 1050  //Features are n x p
 1051  //prob is n x class_count: probability for each sample in each class
1052 
1053  vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1054  "RandomFroest::predictProbabilities():"
1055  " Feature matrix and probability matrix size mismatch.");
 1056  // the feature matrix must have at least as many columns as were
 1057  // used during training; extra columns are ignored
 1058  vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
 1059  "RandomForest::predictProbabilities():"
1060  " Too few columns in feature matrix.");
1061  vigra_precondition( columnCount(prob)
1062  == (MultiArrayIndex)ext_param_.class_count_,
1063  "RandomForestn::predictProbabilities():"
1064  " Probability matrix must have as many columns as there are classes.");
1065  prob.init(0.0);
1066  //store total weights
1067  std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1068  //Go through all trees
1069  int set_id=-1;
1070  for(int k=0; k<options_.tree_count_; ++k)
1071  {
1072  set_id=(set_id+1) % predictionSet.indices[0].size();
1073  typedef std::set<SampleRange<T1> > my_set;
1074  typedef typename my_set::iterator set_it;
1075  //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
1076  //Build a stack with all the ranges we have
1077  std::vector<std::pair<int,set_it> > stack;
1078  stack.clear();
1079  for(set_it i=predictionSet.ranges[set_id].begin();
1080  i!=predictionSet.ranges[set_id].end();++i)
1081  stack.push_back(std::pair<int,set_it>(2,i));
1082  //get weights predicted by single tree
1083  int num_decisions=0;
1084  while(!stack.empty())
1085  {
1086  set_it range=stack.back().second;
1087  int index=stack.back().first;
1088  stack.pop_back();
1089  ++num_decisions;
1090 
1091  if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1092  {
1093  ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1094  trees_[k].parameters_,
1095  index).prob_begin();
1096  for(int i=range->start;i!=range->end;++i)
1097  {
1098  //update votecount.
1099  for(int l=0; l<ext_param_.class_count_; ++l)
1100  {
1101  prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
1102  //every weight in totalWeight.
1103  totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
1104  }
1105  }
1106  }
1107 
1108  else
1109  {
1110  if(trees_[k].topology_[index]!=i_ThresholdNode)
1111  {
1112  throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
1113  }
1114  Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1115  if(range->min_boundaries[node.column()]>=node.threshold())
1116  {
1117  //Everything goes to right child
1118  stack.push_back(std::pair<int,set_it>(node.child(1),range));
1119  continue;
1120  }
1121  if(range->max_boundaries[node.column()]<node.threshold())
1122  {
1123  //Everything goes to the left child
1124  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1125  continue;
1126  }
1127  //We have to split at this node
1128  SampleRange<T1> new_range=*range;
1129  new_range.min_boundaries[node.column()]=FLT_MAX;
1130  range->max_boundaries[node.column()]=-FLT_MAX;
1131  new_range.start=new_range.end=range->end;
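                    //in-place partition: samples that go to the right child
                    //are swapped to the back of the range (growing new_range
                    //from its end), samples that go left stay in front; the
                    //min/max boundaries of both ranges are tightened for the
                    //split dimension along the way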
1132  int i=range->start;
1133  while(i!=range->end)
1134  {
1135  //Decide for range->indices[i]
1136  if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1137  {
1138  new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1139  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1140  --range->end;
1141  --new_range.start;
1142  std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1143 
1144  }
1145  else
1146  {
1147  range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1148  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1149  ++i;
1150  }
1151  }
1152  //The old one ...
1153  if(range->start==range->end)
1154  {
1155  predictionSet.ranges[set_id].erase(range);
1156  }
1157  else
1158  {
1159  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1160  }
1161  //And the new one ...
1162  if(new_range.start!=new_range.end)
1163  {
1164  std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1165  stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1166  }
1167  }
1168  }
1169  predictionSet.cumulativePredTime[k]=num_decisions;
1170  }
1171  for(unsigned int i=0;i<totalWeights.size();++i)
1172  {
1173  double test=0.0;
 1174  //Normalise votes in each row by total VoteCount (totalWeight)
1175  for(int l=0; l<ext_param_.class_count_; ++l)
1176  {
1177  test+=prob(i,l);
1178  prob(i, l) /= totalWeights[i];
1179  }
1180  assert(test==totalWeights[i]);
1181  assert(totalWeights[i]>0.0);
1182  }
1183 }
1184 
1185 template <class LabelType, class PreprocessorTag>
1186 template <class U, class C1, class T, class C2, class Stop_t>
 1187 void RandomForest<LabelType, PreprocessorTag>
1188  ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
1189  MultiArrayView<2, T, C2> & prob,
1190  Stop_t & stop_) const
1191 {
 1192  //Features are n x p
 1193  //prob is n x class_count: probability for each sample in each class
1194 
1195  vigra_precondition(rowCount(features) == rowCount(prob),
1196  "RandomForestn::predictProbabilities():"
1197  " Feature matrix and probability matrix size mismatch.");
1198 
 1199  // the feature matrix must have at least as many columns as were
 1200  // used during training; extra columns are ignored
 1201  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
 1202  "RandomForest::predictProbabilities():"
1203  " Too few columns in feature matrix.");
1204  vigra_precondition( columnCount(prob)
1205  == (MultiArrayIndex)ext_param_.class_count_,
1206  "RandomForestn::predictProbabilities():"
1207  " Probability matrix must have as many columns as there are classes.");
1208 
1209  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1210  Default_Stop_t default_stop(options_);
1211  typename RF_CHOOSER(Stop_t)::type & stop
1212  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1213  #undef RF_CHOOSER
1214  stop.set_external_parameters(ext_param_, tree_count());
1215  prob.init(NumericTraits<T>::zero());
1216  /* This code was originally there for testing early stopping
1217  * - we wanted the order of the trees to be randomized
1218  if(tree_indices_.size() != 0)
1219  {
1220  std::random_shuffle(tree_indices_.begin(),
1221  tree_indices_.end());
1222  }
1223  */
1224  //Classify for each row.
1225  for(int row=0; row < rowCount(features); ++row)
1226  {
1227  ArrayVector<double>::const_iterator weights;
1228 
1229  //totalWeight == totalVoteCount!
1230  double totalWeight = 0.0;
1231 
1232  //Let each tree classify...
1233  for(int k=0; k<options_.tree_count_; ++k)
1234  {
1235  //get weights predicted by single tree
1236  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1237 
1238  //update votecount.
1239  int weighted = options_.predict_weighted_;
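            //when predict_weighted_ is set, each tree's class probabilities
            //are additionally scaled by *(weights-1), the entry stored just
            //before them (the total weight of training samples in the leaf);
            //otherwise the probabilities are summed unscaled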
1240  for(int l=0; l<ext_param_.class_count_; ++l)
1241  {
1242  double cur_w = weights[l] * (weighted * (*(weights-1))
1243  + (1-weighted));
1244  prob(row, l) += (T)cur_w;
1245  //every weight in totalWeight.
1246  totalWeight += cur_w;
1247  }
1248  if(stop.after_prediction(weights,
1249  k,
1250  rowVector(prob, row),
1251  totalWeight))
1252  {
1253  break;
1254  }
1255  }
1256 
 1257  //Normalise votes in each row by total VoteCount (totalWeight)
1258  for(int l=0; l< ext_param_.class_count_; ++l)
1259  {
1260  prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1261  }
1262  }
1263 
1264 }
1265 
1266 template <class LabelType, class PreprocessorTag>
1267 template <class U, class C1, class T, class C2>
1268 void RandomForest<LabelType, PreprocessorTag>
1269  ::predictRaw(MultiArrayView<2, U, C1>const & features,
1270  MultiArrayView<2, T, C2> & prob) const
1271 {
 1272  //Features are n x p
 1273  //prob is n x class_count: probability for each sample in each class
1274 
1275  vigra_precondition(rowCount(features) == rowCount(prob),
1276  "RandomForestn::predictProbabilities():"
1277  " Feature matrix and probability matrix size mismatch.");
1278 
 1279  // the feature matrix must have at least as many columns as were
 1280  // used during training; extra columns are ignored
 1281  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
 1282  "RandomForest::predictRaw():"
1283  " Too few columns in feature matrix.");
1284  vigra_precondition( columnCount(prob)
1285  == (MultiArrayIndex)ext_param_.class_count_,
1286  "RandomForestn::predictProbabilities():"
1287  " Probability matrix must have as many columns as there are classes.");
1288 
1289  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1290  prob.init(NumericTraits<T>::zero());
1291  /* This code was originally there for testing early stopping
1292  * - we wanted the order of the trees to be randomized
1293  if(tree_indices_.size() != 0)
1294  {
1295  std::random_shuffle(tree_indices_.begin(),
1296  tree_indices_.end());
1297  }
1298  */
1299  //Classify for each row.
1300  for(int row=0; row < rowCount(features); ++row)
1301  {
1302  ArrayVector<double>::const_iterator weights;
1303 
1304  //totalWeight == totalVoteCount!
1305  double totalWeight = 0.0;
1306 
1307  //Let each tree classify...
1308  for(int k=0; k<options_.tree_count_; ++k)
1309  {
1310  //get weights predicted by single tree
1311  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1312 
1313  //update votecount.
1314  int weighted = options_.predict_weighted_;
1315  for(int l=0; l<ext_param_.class_count_; ++l)
1316  {
1317  double cur_w = weights[l] * (weighted * (*(weights-1))
1318  + (1-weighted));
1319  prob(row, l) += (T)cur_w;
1320  //every weight in totalWeight.
1321  totalWeight += cur_w;
1322  }
1323  }
1324  }
 1325  prob /= options_.tree_count_;
1326 
1327 }
1328 
1329 //@}
1330 
1331 } // namespace vigra
1332 
1333 #include "random_forest/rf_algorithm.hxx"
1334 #endif // VIGRA_RANDOM_FOREST_HXX
