
vigra/random_forest_deprec.hxx
/************************************************************************/
/*                                                                      */
/*                  Copyright 2008 by Ullrich Koethe                    */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <algorithm>
#include <map>
#include <numeric>
#include <iostream>
#include "vigra/mathutil.hxx"
#include "vigra/array_vector.hxx"
#include "vigra/sized_int.hxx"
#include "vigra/matrix.hxx"
#include "vigra/random.hxx"
#include "vigra/functorexpression.hxx"

#define RandomForest RandomForestDeprec
#define DecisionTree DecisionTreeDeprec

namespace vigra
{

/** \addtogroup MachineLearning
**/
//@{

namespace detail
{

template<class DataMatrix>
class RandomForestFeatureSorter
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;

  public:

    RandomForestFeatureSorter(DataMatrix const & data, MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }
};
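
// Typical use of RandomForestFeatureSorter (an illustrative sketch, not part
// of the original code; "features", "c", and "indices" stand for a feature
// matrix, a column index, and an ArrayVector<int> of example indices): the
// comparator reorders the index array by feature column c without moving the
// feature matrix itself. findBestSplit() below relies on this pattern.
//
//     RandomForestFeatureSorter<Matrix<double> > byColumn(features, c);
//     std::sort(indices.begin(), indices.end(), byColumn);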

template<class LabelArray>
class RandomForestLabelSorter
{
    LabelArray const & labels_;

  public:

    RandomForestLabelSorter(LabelArray const & labels)
    : labels_(labels)
    {}

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return labels_[l] < labels_[r];
    }
};

template <class CountArray>
class RandomForestClassCounter
{
    ArrayVector<int> const & labels_;
    CountArray & counts_;

  public:

    RandomForestClassCounter(ArrayVector<int> const & labels, CountArray & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    void reset()
    {
        counts_.init(0);
    }

    void operator()(MultiArrayIndex l) const
    {
        ++counts_[labels_[l]];
    }
};

struct DecisionTreeCountNonzeroFunctor
{
    double operator()(double old, double other) const
    {
        if(other != 0.0)
            ++old;
        return old;
    }
};

struct DecisionTreeNode
{
    DecisionTreeNode(int t, MultiArrayIndex bestColumn)
    : thresholdIndex(t), splitColumn(bestColumn)
    {}

    int children[2];
    int thresholdIndex;
    Int32 splitColumn;
};
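
// Note: DecisionTreeNode spells out the logical contents of a tree node, but
// it appears to be unused in this file; the tree is actually stored as a flat
// Int32 array that is accessed through DecisionTreeNodeProxy below.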

template <class INT>
struct DecisionTreeNodeProxy
{
    DecisionTreeNodeProxy(ArrayVector<INT> const & tree, INT n)
    : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
    {}

    INT & child(INT l) const
    {
        return node[l];
    }

    INT & decisionWeightsIndex() const
    {
        return node[2];
    }

    typename ArrayVector<INT>::iterator decisionColumns() const
    {
        return node+3;
    }

    mutable typename ArrayVector<INT>::iterator node;
};
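
// Memory layout of an interior node in the flat tree array (see
// DecisionTreeAxisSplitFunctor::writeSplitParameters() below, which appends
// exactly sizeofNode() == 4 entries per node):
//
//     node[0]  index of the left child; values <= 0 denote a leaf and are
//              the negated index of its weights in terminalWeights_
//     node[1]  index of the right child (same convention)
//     node[2]  index of the node's threshold in terminalWeights_
//     node[3]  feature column tested at this node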

struct DecisionTreeAxisSplitFunctor
{
    ArrayVector<Int32> splitColumns;
    ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
    double threshold;
    double totalCounts[2], bestTotalCounts[2];
    int mtry, classCount, bestSplitColumn;
    bool pure[2], isWeighted;

    void init(int mtry, int cols, int classCount, ArrayVector<double> const & weights)
    {
        this->mtry = mtry;
        splitColumns.resize(cols);
        for(int k=0; k<cols; ++k)
            splitColumns[k] = k;

        this->classCount = classCount;
        classCounts.resize(classCount);
        currentCounts[0].resize(classCount);
        currentCounts[1].resize(classCount);
        bestCounts[0].resize(classCount);
        bestCounts[1].resize(classCount);

        isWeighted = weights.size() > 0;
        if(isWeighted)
            classWeights = weights;
        else
            classWeights.resize(classCount, 1.0);
    }

    bool isPure(int k) const
    {
        return pure[k];
    }

    unsigned int totalCount(int k) const
    {
        return (unsigned int)bestTotalCounts[k];
    }

    int sizeofNode() const { return 4; }

    int writeSplitParameters(ArrayVector<Int32> & tree,
                             ArrayVector<double> & terminalWeights)
    {
        int currentWeightIndex = terminalWeights.size();
        terminalWeights.push_back(threshold);

        int currentNodeIndex = tree.size();
        tree.push_back(-1);  // left child
        tree.push_back(-1);  // right child
        tree.push_back(currentWeightIndex);
        tree.push_back(bestSplitColumn);

        return currentNodeIndex;
    }

    void writeWeights(int l, ArrayVector<double> & terminalWeights)
    {
        for(int k=0; k<classCount; ++k)
            terminalWeights.push_back(isWeighted
                                          ? bestCounts[l][k]
                                          : bestCounts[l][k] / totalCount(l));
    }

    template <class U, class C, class AxesIterator, class WeightIterator>
    bool decideAtNode(MultiArrayView<2, U, C> const & features,
                      AxesIterator a, WeightIterator w) const
    {
        return (features(0, *a) < *w);
    }

    template <class U, class C, class IndexIterator, class Random>
    IndexIterator findBestSplit(MultiArrayView<2, U, C> const & features,
                                ArrayVector<int> const & labels,
                                IndexIterator indices, int exampleCount,
                                Random & randint);
};


template <class U, class C, class IndexIterator, class Random>
IndexIterator
DecisionTreeAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> const & features,
                                            ArrayVector<int> const & labels,
                                            IndexIterator indices, int exampleCount,
                                            Random & randint)
{
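    // Contract (as used by DecisionTree::learn() below): reorder
    // indices[0..exampleCount) so that the examples of the left child come
    // first, record the threshold and the per-child class counts in the
    // functor's members, and return an iterator to the first example of the
    // right child.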
    // select columns to be tried for split
    for(int k=0; k<mtry; ++k)
        std::swap(splitColumns[k], splitColumns[k+randint(columnCount(features)-k)]);

    RandomForestFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
    RandomForestClassCounter<ArrayVector<double> > counter(labels, classCounts);
    std::for_each(indices, indices+exampleCount, counter);

    // find the best gini index
    double minGini = NumericTraits<double>::max();
    IndexIterator bestSplit;
    for(int k=0; k<mtry; ++k)
    {
        sorter.setColumn(splitColumns[k]);
        std::sort(indices, indices+exampleCount, sorter);

        currentCounts[0].init(0);
        std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
                       currentCounts[1].begin(), std::multiplies<double>());
        totalCounts[0] = 0;
        totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
        for(int m = 0; m < exampleCount-1; ++m)
        {
            int label = labels[indices[m]];
            double w = classWeights[label];
            currentCounts[0][label] += w;
            totalCounts[0] += w;
            currentCounts[1][label] -= w;
            totalCounts[1] -= w;

            if (m < exampleCount-2 &&
                features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
                continue;

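            // Gini impurity of the candidate split: for child c with class
            // counts n_{c,l} and total N_c,
            //
            //     gini_c = sum_l n_{c,l} * (1 - n_{c,l} / N_c)
            //            = N_c * (1 - sum_l p_{c,l}^2),
            //
            // and the score of the split is gini_0 + gini_1. For two classes
            // this reduces to 2 * n_{c,0} * n_{c,1} / N_c; the code below
            // drops the constant factor 2, which does not change the minimum.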
            double gini = 0.0;
            if(classCount == 2)
            {
                gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
                       currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
            }
            else
            {
                for(int l=0; l<classCount; ++l)
                    gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
                            currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
            }
            if(gini < minGini)
            {
                minGini = gini;
                bestSplit = indices+m;
                bestSplitColumn = splitColumns[k];
                bestCounts[0] = currentCounts[0];
                bestCounts[1] = currentCounts[1];
            }
        }
    }
    //std::cerr << minGini << " " << bestSplitColumn << std::endl;

    // split using the best feature
    sorter.setColumn(bestSplitColumn);
    std::sort(indices, indices+exampleCount, sorter);

    for(int k=0; k<2; ++k)
    {
        bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
    }

    threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
    ++bestSplit;

    counter.reset();
    std::for_each(indices, bestSplit, counter);
    pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeCountNonzeroFunctor());
    counter.reset();
    std::for_each(bestSplit, indices+exampleCount, counter);
    pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeCountNonzeroFunctor());

    return bestSplit;
}

enum { DecisionTreeNoParent = -1 };

template <class Iterator>
struct DecisionTreeStackEntry
{
    DecisionTreeStackEntry(Iterator i, int c,
                           int lp = DecisionTreeNoParent, int rp = DecisionTreeNoParent)
    : indices(i), exampleCount(c),
      leftParent(lp), rightParent(rp)
    {}

    Iterator indices;
    int exampleCount, leftParent, rightParent;
};

class DecisionTree
{
  public:
    typedef Int32 TreeInt;
    ArrayVector<TreeInt> tree_;
    ArrayVector<double> terminalWeights_;
    unsigned int classCount_;
    DecisionTreeAxisSplitFunctor split;

    DecisionTree(unsigned int classCount)
    : classCount_(classCount)
    {}

    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        tree_.clear();
        terminalWeights_.clear();
    }

    template <class U, class C, class Iterator, class Options, class Random>
    void learn(MultiArrayView<2, U, C> const & features,
               ArrayVector<int> const & labels,
               Iterator indices, int exampleCount,
               Options const & options,
               Random & randint);

    template <class U, class C>
    ArrayVector<double>::const_iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            if(nodeindex <= 0)
                return terminalWeights_.begin() + (-nodeindex);
        }
    }

    template <class U, class C>
    int
    predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights+classCount_) - weights;
    }

    template <class U, class C>
    int
    leafID(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            if(nodeindex <= 0)
                return -nodeindex;
        }
    }

    void depth(int & maxDep, int & interiorCount, int & leafCount, int k = 0, int d = 1) const
    {
        DecisionTreeNodeProxy<TreeInt> node(tree_, k);
        ++interiorCount;
        ++d;
        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child > 0)
                depth(maxDep, interiorCount, leafCount, child, d);
            else
            {
                ++leafCount;
                if(maxDep < d)
                    maxDep = d;
            }
        }
    }

    void printStatistics(std::ostream & o) const
    {
        int maxDep = 0, interiorCount = 0, leafCount = 0;
        depth(maxDep, interiorCount, leafCount);

        o << "interior nodes: " << interiorCount <<
             ", terminal nodes: " << leafCount <<
             ", depth: " << maxDep << "\n";
    }

    void print(std::ostream & o, int k = 0, std::string s = "") const
    {
        DecisionTreeNodeProxy<TreeInt> node(tree_, k);
        o << s << (*node.decisionColumns()) << " " << terminalWeights_[node.decisionWeightsIndex()] << "\n";

        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child <= 0)
                o << s << " weights " << terminalWeights_[-child] << " "
                                      << terminalWeights_[-child+1] << "\n";
            else
                print(o, child, s+" ");
        }
    }
};


template <class U, class C, class Iterator, class Options, class Random>
void DecisionTree::learn(MultiArrayView<2, U, C> const & features,
                         ArrayVector<int> const & labels,
                         Iterator indices, int exampleCount,
                         Options const & options,
                         Random & randint)
{
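    // Iterative tree construction with an explicit stack: each iteration pops
    // a subset of example indices, finds the best axis-parallel split, appends
    // the corresponding interior node to tree_, patches the parent's child
    // pointer, and pushes the two child subsets. A child that is pure or
    // smaller than min_split_node_size becomes a leaf instead, stored as the
    // negated index of its class weights in terminalWeights_.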
    ArrayVector<double> const & classLoss = options.class_weights;

    vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
        "DecisionTree::learn(): class weights array has wrong size.");

    reset();

    unsigned int mtry = options.mtry;
    MultiArrayIndex cols = columnCount(features);

    split.init(mtry, cols, classCount_, classLoss);

    typedef DecisionTreeStackEntry<Iterator> Entry;
    ArrayVector<Entry> stack;
    stack.push_back(Entry(indices, exampleCount));

    while(!stack.empty())
    {
//        std::cerr << "*";
        indices = stack.back().indices;
        exampleCount = stack.back().exampleCount;
        int leftParent  = stack.back().leftParent,
            rightParent = stack.back().rightParent;

        stack.pop_back();

        Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);

        int currentNode = split.writeSplitParameters(tree_, terminalWeights_);

        if(leftParent != DecisionTreeNoParent)
            DecisionTreeNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
        if(rightParent != DecisionTreeNoParent)
            DecisionTreeNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
        leftParent = currentNode;
        rightParent = DecisionTreeNoParent;

        for(int l=0; l<2; ++l)
        {
            if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
            {
                // sample is still large enough and not yet perfectly separated => split
                stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
            }
            else
            {
                DecisionTreeNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();

                split.writeWeights(l, terminalWeights_);
            }
            std::swap(leftParent, rightParent);
            indices = bestSplit;
        }
    }
//    std::cerr << "\n";
}

} // namespace detail

class RandomForestOptions
{
  public:
        /** Initialize all options with default values.
        */
    RandomForestOptions()
    : training_set_proportion(1.0),
      mtry(0),
      min_split_node_size(1),
      training_set_size(0),
      sample_with_replacement(true),
      sample_classes_individually(false),
      treeCount(255)
    {}

        /** Number of features considered in each node.

            If \a n is 0 (the default), the number of features tried in every node
            is the square root of the total number of features, rounded to the
            nearest integer. According to Breiman, this quantity should always be
            optimized by means of the out-of-bag error.<br>
            Default: 0 (use <tt>sqrt(columnCount(featureMatrix))</tt>)
        */
    RandomForestOptions & featuresPerNode(unsigned int n)
    {
        mtry = n;
        return *this;
    }

        /** How to sample the subset of the training data for each tree.

            Each tree is only trained with a subset of the entire training data.
            If \a r is <tt>true</tt>, this subset is sampled from the entire training set with
            replacement.<br>
            Default: <tt>true</tt> (use sampling with replacement)
        */
    RandomForestOptions & sampleWithReplacement(bool r)
    {
        sample_with_replacement = r;
        return *this;
    }

        /** Number of trees in the forest.<br>
            Default: 255
        */
    RandomForestOptions & setTreeCount(unsigned int cnt)
    {
        treeCount = cnt;
        return *this;
    }

        /** Proportion of training examples used for each tree.

            If \a p is 1.0 (the default), and samples are drawn with replacement,
            the training set of each tree contains as many examples as the entire
            training set, but some are drawn multiple times and others not at all:
            the probability that a given example is never drawn is (1 - 1/n)^n,
            which approaches 1/e for large n, so each tree is effectively trained
            on about 63% of the examples in the full training set. Changing the
            proportion mainly makes sense when sampleWithReplacement() is set to
            <tt>false</tt>. trainingSetSizeProportional() is overridden by
            trainingSetSizeAbsolute().<br>
            Default: 1.0
        */
    RandomForestOptions & trainingSetSizeProportional(double p)
    {
        vigra_precondition(p >= 0.0 && p <= 1.0,
            "RandomForestOptions::trainingSetSizeProportional(): proportion must be in [0, 1].");
        if(training_set_size == 0) // otherwise, absolute size gets priority
            training_set_proportion = p;
        return *this;
    }

        /** Size of the training set for each tree.

            If this option is set, it overrides the proportion set by
            trainingSetSizeProportional(). When classes are sampled individually,
            the number of examples is divided by the number of classes (rounded upwards)
            to determine the number of examples drawn from every class.<br>
            Default: <tt>0</tt> (determine size by proportion)
        */
    RandomForestOptions & trainingSetSizeAbsolute(unsigned int s)
    {
        training_set_size = s;
        if(s > 0)
            training_set_proportion = 0.0;
        return *this;
    }

        /** Are the classes sampled individually?

            If \a s is <tt>false</tt> (the default), the training set for each tree is sampled
            without considering class labels. Otherwise, samples are drawn from each
            class independently. The latter is especially useful in connection
            with the specification of an absolute training set size: then, the same number of
            examples is drawn from every class. This can be used as a counter-measure when the
            classes are very unbalanced in size.<br>
            Default: <tt>false</tt>
        */
    RandomForestOptions & sampleClassesIndividually(bool s)
    {
        sample_classes_individually = s;
        return *this;
    }

        /** Number of examples required for a node to be split.

            When the number of examples in a node is below this number, the node is not
            split even if class separation is not yet perfect. Instead, the node returns
            the proportion of each class (among the remaining examples) during the
            prediction phase.<br>
            Default: 1 (complete growing)
        */
    RandomForestOptions & minSplitNodeSize(unsigned int n)
    {
        if(n == 0)
            n = 1;
        min_split_node_size = n;
        return *this;
    }

        /** Use a weighted random forest.

            This is usually used to penalize the errors for the minority class.
            Weights must be convertible to <tt>double</tt>, and the array of weights
            must contain as many entries as there are classes.<br>
            Default: do not use weights
        */
    template <class WeightIterator>
    RandomForestOptions & weights(WeightIterator weights, unsigned int classCount)
    {
        class_weights.clear();
        if(weights != 0)
            class_weights.insert(weights, classCount);
        return *this;
    }

    RandomForestOptions & oobData(MultiArrayView<2, UInt8> & data)
    {
        oob_data = data;
        return *this;
    }

    MultiArrayView<2, UInt8> oob_data;
    ArrayVector<double> class_weights;
    double training_set_proportion;
    unsigned int mtry, min_split_node_size, training_set_size;
    bool sample_with_replacement, sample_classes_individually;
    unsigned int treeCount;
};
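
/* Typical option setup (an illustrative sketch, not part of the original
   header; the chosen values are arbitrary):

       RandomForestOptions options = RandomForestOptions()
                                         .sampleWithReplacement(false)
                                         .trainingSetSizeProportional(0.8)
                                         .featuresPerNode(10)
                                         .minSplitNodeSize(5);

   Every setter returns *this, so the calls can be chained in any order; the
   resulting object is then passed to the RandomForest constructors below.
*/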

/*****************************************************************/
/*                                                               */
/*                          RandomForest                         */
/*                                                               */
/*****************************************************************/

template <class ClassLabelType>
class RandomForest
{
  public:
    ArrayVector<ClassLabelType> classes_;
    ArrayVector<detail::DecisionTree> trees_;
    MultiArrayIndex columnCount_;
    RandomForestOptions options_;

  public:

    // The first two constructors are straightforward: they take either
    // iterators to an array of class labels, or the two label values directly.
    template<class ClassLabelIterator>
    RandomForest(ClassLabelIterator cl, ClassLabelIterator cend,
                 unsigned int treeCount = 255,
                 RandomForestOptions const & options = RandomForestOptions())
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTree(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptions: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptions::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptions::weights(): wrong number of classes.");
    }

    RandomForest(ClassLabelType const & c1, ClassLabelType const & c2,
                 unsigned int treeCount = 255,
                 RandomForestOptions const & options = RandomForestOptions())
    : classes_(2),
      trees_(treeCount, detail::DecisionTree(2)),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
            "RandomForestOptions::weights(): wrong number of classes.");
        classes_[0] = c1;
        classes_[1] = c2;
    }

    // This constructor is provided especially for the CrossValidator class:
    // it reads the tree count from the options object.
    template<class ClassLabelIterator>
    RandomForest(ClassLabelIterator cl, ClassLabelIterator cend,
                 RandomForestOptions const & options)
    : classes_(cl, cend),
      trees_(options.treeCount, detail::DecisionTree(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptions: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptions::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptions::weights(): wrong number of classes.");
    }

    // Reconstructs a previously trained forest: takes the feature column
    // count directly instead of an options object, and copies each tree's
    // flat representation and terminal weights from the given iterators.
    template<class ClassLabelIterator, class TreeIterator, class WeightIterator>
    RandomForest(ClassLabelIterator cl, ClassLabelIterator cend,
                 unsigned int treeCount, unsigned int columnCount,
                 TreeIterator trees, WeightIterator weights)
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTree(classes_.size())),
      columnCount_(columnCount)
    {
        for(unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].tree_ = *trees;
            trees_[k].terminalWeights_ = *weights;
        }
    }

    int featureCount() const
    {
        vigra_precondition(columnCount_ > 0,
           "RandomForest::featureCount(): Random forest has not been trained yet.");
        return columnCount_;
    }

    int labelCount() const
    {
        return classes_.size();
    }

    int treeCount() const
    {
        return trees_.size();
    }

    // an empty class_weights array in the options means an unweighted random forest
    template <class U, class C, class Array, class Random>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels,
                 Random const & random);

    template <class U, class C, class Array>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels)
    {
        return learn(features, labels, RandomTT800::global());
    }

    template <class U, class C>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features) const;

    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = predictLabel(rowVector(features, k));
    }

    template <class U, class C, class Iterator>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                Iterator priors) const;

    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const;

    template <class U, class C1, class T, class C2>
    void predictNodes(MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, T, C2> & NodeIDs) const;
};
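
/* End-to-end usage sketch (illustrative only; "features", "testFeatures", and
   "labels" are assumed to be defined elsewhere, with features an n x p
   Matrix<double> and labels an ArrayVector<int> holding the values 0 and 1):

       RandomForest<int> rf(0, 1, 255);   // two classes, 255 trees
       double oobError = rf.learn(features, labels);

       Matrix<double> prob(rowCount(testFeatures), rf.labelCount());
       rf.predictProbabilities(testFeatures, prob);
       int label = rf.predictLabel(rowVector(testFeatures, 0));

   learn() returns the out-of-bag error estimate, so no separate test set is
   strictly required to judge the forest's accuracy.
*/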

template <class ClassLabelType>
template <class U, class C1, class Array, class Random>
double
RandomForest<ClassLabelType>::learn(MultiArrayView<2, U, C1> const & features,
                                    Array const & labels,
                                    Random const & random)
{
    unsigned int classCount = classes_.size();
    unsigned int m = rowCount(features);
    unsigned int n = columnCount(features);
    vigra_precondition((unsigned int)(m) == (unsigned int)labels.size(),
      "RandomForest::learn(): Label array has wrong size.");

    vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
       "RandomForest::learn(): Requested training set size exceeds total number of examples.");

    MultiArrayIndex mtry = (options_.mtry == 0)
                                ? int(std::floor(std::sqrt(double(n)) + 0.5))
                                : options_.mtry;

    vigra_precondition(mtry <= (MultiArrayIndex)n,
       "RandomForest::learn(): mtry must not exceed the number of features.");

    MultiArrayIndex msamples = options_.training_set_size;
    if(options_.sample_classes_individually)
        msamples = int(std::ceil(double(msamples) / classCount));

    ArrayVector<int> intLabels(m), classExampleCounts(classCount);

    // verify the input labels
    int minClassCount;
    {
        typedef std::map<ClassLabelType, int> LabelChecker;
        typedef typename LabelChecker::iterator LabelCheckerIterator;
        LabelChecker labelChecker;
        for(unsigned int k=0; k<classCount; ++k)
            labelChecker[classes_[k]] = k;

        for(unsigned int k=0; k<m; ++k)
        {
            LabelCheckerIterator found = labelChecker.find(labels[k]);
            vigra_precondition(found != labelChecker.end(),
                "RandomForest::learn(): Unknown class label encountered.");
            intLabels[k] = found->second;
            ++classExampleCounts[intLabels[k]];
        }
        minClassCount = *argMin(classExampleCounts.begin(), classExampleCounts.end());
        vigra_precondition(minClassCount > 0,
             "RandomForest::learn(): At least one class is missing in the training set.");
        if(msamples > 0 && options_.sample_classes_individually &&
                          !options_.sample_with_replacement)
        {
            vigra_precondition(msamples <= minClassCount,
                "RandomForest::learn(): Too few examples in smallest class to reach "
                "requested training set size.");
        }
    }
    columnCount_ = n;
    ArrayVector<int> indices(m);
    for(unsigned int k=0; k<m; ++k)
        indices[k] = k;

    if(options_.sample_classes_individually)
    {
        detail::RandomForestLabelSorter<ArrayVector<int> > sorter(intLabels);
        std::sort(indices.begin(), indices.end(), sorter);
    }

    ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);

    UniformIntRandomFunctor<Random> randint(0, m-1, random);
    //std::cerr << "Learning a RF \n";
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        //std::cerr << "Learning tree " << k << " ...\n";

        ArrayVector<int> trainingSet;
        usedIndices.init(0);

        if(options_.sample_classes_individually)
        {
            int first = 0;
            for(unsigned int l=0; l<classCount; ++l)
            {
                int lc = classExampleCounts[l];
                int lsamples = (msamples == 0)
                                   ? int(std::ceil(options_.training_set_proportion*lc))
                                   : msamples;

                if(options_.sample_with_replacement)
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        trainingSet.push_back(indices[first+randint(lc)]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                else
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
                        trainingSet.push_back(indices[first+ll]);
                        ++usedIndices[trainingSet.back()];
                    }
                    //std::sort(indices.begin(), indices.begin()+lsamples);
                }
                first += lc;
            }
        }
        else
        {
            if(msamples == 0)
                msamples = int(std::ceil(options_.training_set_proportion*m));

            if(options_.sample_with_replacement)
            {
                for(int l=0; l<msamples; ++l)
                {
                    trainingSet.push_back(indices[randint(m)]);
                    ++usedIndices[trainingSet.back()];
                }
            }
            else
            {
                for(int l=0; l<msamples; ++l)
                {
                    std::swap(indices[l], indices[l+randint(m-l)]);
                    trainingSet.push_back(indices[l]);
                    ++usedIndices[trainingSet.back()];
                }
            }
        }
        trees_[k].learn(features, intLabels,
                        trainingSet.begin(), trainingSet.size(),
                        options_.featuresPerNode(mtry), randint);

        for(unsigned int l=0; l<m; ++l)
        {
            if(!usedIndices[l])
            {
                ++oobCount[l];
                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
                {
                    ++oobErrorCount[l];
                    if(options_.oob_data.data() != 0)
                        options_.oob_data(l, k) = 2;
                }
                else if(options_.oob_data.data() != 0)
                {
                    options_.oob_data(l, k) = 1;
                }
            }
        }
        // TODO: default value for oob_data
        // TODO: implement variable importance
        //std::cerr << "done\n";
        //trees_[k].print(std::cerr);
        #ifdef VIGRA_RF_VERBOSE
        trees_[k].printStatistics(std::cerr);
        #endif
    }
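    // The returned out-of-bag error is the average, over all examples that
    // were out-of-bag for at least one tree, of each example's fraction of
    // incorrect out-of-bag predictions.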
    double oobError = 0.0;
    int totalOobCount = 0;
    for(unsigned int l=0; l<m; ++l)
        if(oobCount[l])
        {
            oobError += double(oobErrorCount[l]) / oobCount[l];
            ++totalOobCount;
        }
    return oobError / totalOobCount;
}

template <class ClassLabelType>
template <class U, class C>
ClassLabelType
RandomForest<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    return classes_[argMax(prob)];
}

// Same as above, but multiplies the predicted class probabilities by
// user-supplied per-class priors before taking the argmax.
template <class ClassLabelType>
template <class U, class C, class Iterator>
ClassLabelType
RandomForest<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features,
                                           Iterator priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
    return classes_[argMax(prob)];
}

template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForest<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> const & features,
                                                   MultiArrayView<2, T, C2> & prob) const
{
    // features is n x p; prob is n x labelCount() and receives, for each
    // feature row, the predicted probability of each class

    vigra_precondition(rowCount(features) == rowCount(prob),
      "RandomForest::predictProbabilities(): Feature matrix and probability matrix size mismatch.");

    // the feature matrix may have more columns than the forest was trained
    // on; only the first featureCount() columns are used
    vigra_precondition(columnCount(features) >= featureCount(),
      "RandomForest::predictProbabilities(): Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob) == (MultiArrayIndex)labelCount(),
      "RandomForest::predictProbabilities(): Probability matrix must have as many columns as there are classes.");

    // classify each row
    for(int row=0; row < rowCount(features); ++row)
    {
        // per-class weights returned by a single tree: each leaf stores the
        // class proportions (or weighted counts) of its training examples,
        // so a tree casts a fractional vote rather than a single hard vote
        ArrayVector<double>::const_iterator weights;

        // sum of all votes cast for this row
        double totalWeight = 0.0;

        // initialize the vote counts; prob(row, l) accumulates votes until
        // the final normalisation
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) = 0.0;

        // let each tree classify...
        for(unsigned int k=0; k<trees_.size(); ++k)
        {
            // get the weights predicted by a single tree
            weights = trees_[k].predict(rowVector(features, row));

            // update the vote counts
            for(unsigned int l=0; l<classes_.size(); ++l)
            {
                prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
                totalWeight += weights[l];
            }
        }

        // normalise the votes in each row by the total vote count (totalWeight)
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
    }
}

template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForest<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> const & features,
                                           MultiArrayView<2, T, C2> & NodeIDs) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
      "RandomForest::predictNodes(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) <= rowCount(NodeIDs),
      "RandomForest::predictNodes(): Too few rows in NodeIDs matrix.");
    vigra_precondition(columnCount(NodeIDs) >= treeCount(),
      "RandomForest::predictNodes(): Too few columns in NodeIDs matrix.");
    NodeIDs.init(0);
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        for(int row=0; row < rowCount(features); ++row)
        {
            NodeIDs(row,k) = trees_[k].leafID(rowVector(features, row));
        }
    }
}

//@}

} // namespace vigra

#undef RandomForest
#undef DecisionTree

#endif // VIGRA_RANDOM_FOREST_HXX
