/************************************************************************/
/*                                                                      */
/*               Copyright 2008 by Ullrich Koethe                       */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,   */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <algorithm>
#include <map>
#include <numeric>
#include <iostream>
#include "vigra/mathutil.hxx"
#include "vigra/array_vector.hxx"
#include "vigra/sized_int.hxx"
#include "vigra/matrix.hxx"
#include "vigra/random.hxx"
#include "vigra/functorexpression.hxx"

// These macros rename the classes defined below to their *Deprec variants
// (they are #undef'ed at the end of the file), so the public names are
// RandomForestDeprec and DecisionTreeDeprec.
#define RandomForest RandomForestDeprec
#define DecisionTree DecisionTreeDeprec

namespace vigra
{

/** \addtogroup MachineLearning
**/
//@{

namespace detail
{

template<class DataMatrix>
class RandomForestFeatureSorter
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;

  public:

    RandomForestFeatureSorter(DataMatrix const & data, MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    // compare two example indices by their feature value in the sort column
    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }
};

template<class LabelArray>
class RandomForestLabelSorter
{
    LabelArray const & labels_;

  public:

    RandomForestLabelSorter(LabelArray const & labels)
    : labels_(labels)
    {}

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return labels_[l] < labels_[r];
    }
};

template <class CountArray>
class RandomForestClassCounter
{
    ArrayVector<int> const & labels_;
    CountArray & counts_;

  public:

    RandomForestClassCounter(ArrayVector<int> const & labels, CountArray & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    void reset()
    {
        counts_.init(0);
    }

    // count the class label of example l
    void operator()(MultiArrayIndex l) const
    {
        ++counts_[labels_[l]];
    }
};

// accumulation functor that counts the number of non-zero entries
struct DecisionTreeCountNonzeroFunctor
{
    double operator()(double old, double other) const
    {
        if(other != 0.0)
            ++old;
        return old;
    }
};

struct DecisionTreeNode
{
    DecisionTreeNode(int t, MultiArrayIndex bestColumn)
    : thresholdIndex(t), splitColumn(bestColumn)
    {}

    int children[2];
    int thresholdIndex;
    Int32 splitColumn;
};

// light-weight view of a node stored in the flat tree array:
// node[0] and node[1] are the child indices, node[2] is the index of
// the split threshold in the weight array, node[3] is the split column
template <class INT>
struct DecisionTreeNodeProxy
{
    DecisionTreeNodeProxy(ArrayVector<INT> const & tree, INT n)
    : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
    {}

    INT & child(INT l) const
    {
        return node[l];
    }

    INT & decisionWeightsIndex() const
    {
        return node[2];
    }

    typename ArrayVector<INT>::iterator decisionColumns() const
    {
        return node+3;
    }

    mutable typename ArrayVector<INT>::iterator node;
};

struct DecisionTreeAxisSplitFunctor
{
    ArrayVector<Int32> splitColumns;
    ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
    double threshold;
    double totalCounts[2], bestTotalCounts[2];
    int mtry, classCount, bestSplitColumn;
    bool pure[2], isWeighted;

    void init(int mtry, int cols, int classCount, ArrayVector<double> const & weights)
    {
        this->mtry = mtry;
        splitColumns.resize(cols);
        for(int k=0; k<cols; ++k)
            splitColumns[k] = k;

        this->classCount = classCount;
        classCounts.resize(classCount);
        currentCounts[0].resize(classCount);
        currentCounts[1].resize(classCount);
        bestCounts[0].resize(classCount);
        bestCounts[1].resize(classCount);

        isWeighted = weights.size() > 0;
        if(isWeighted)
            classWeights = weights;
        else
            classWeights.resize(classCount, 1.0);
    }

    bool isPure(int k) const
    {
        return pure[k];
    }

    unsigned int totalCount(int k) const
    {
        return (unsigned int)bestTotalCounts[k];
    }

    int sizeofNode() const { return 4; }

    int writeSplitParameters(ArrayVector<Int32> & tree,
                             ArrayVector<double> & terminalWeights)
    {
        int currentWeightIndex = terminalWeights.size();
        terminalWeights.push_back(threshold);

        int currentNodeIndex = tree.size();
        tree.push_back(-1);  // left child
        tree.push_back(-1);  // right child
        tree.push_back(currentWeightIndex);
        tree.push_back(bestSplitColumn);

        return currentNodeIndex;
    }

    void writeWeights(int l, ArrayVector<double> & terminalWeights)
    {
        for(int k=0; k<classCount; ++k)
            terminalWeights.push_back(isWeighted
                                          ? bestCounts[l][k]
                                          : bestCounts[l][k] / totalCount(l));
    }

    template <class U, class C, class AxesIterator, class WeightIterator>
    bool decideAtNode(MultiArrayView<2, U, C> const & features,
                      AxesIterator a, WeightIterator w) const
    {
        return (features(0, *a) < *w);
    }

    template <class U, class C, class IndexIterator, class Random>
    IndexIterator findBestSplit(MultiArrayView<2, U, C> const & features,
                                ArrayVector<int> const & labels,
                                IndexIterator indices, int exampleCount,
                                Random & randint);
};


template <class U, class C, class IndexIterator, class Random>
IndexIterator
DecisionTreeAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> const & features,
                                            ArrayVector<int> const & labels,
                                            IndexIterator indices, int exampleCount,
                                            Random & randint)
{
    // select mtry columns at random to be tried for the split
    for(int k=0; k<mtry; ++k)
        std::swap(splitColumns[k], splitColumns[k+randint(columnCount(features)-k)]);

    RandomForestFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
    RandomForestClassCounter<ArrayVector<double> > counter(labels, classCounts);
    std::for_each(indices, indices+exampleCount, counter);

    // find the split with the smallest Gini impurity
    double minGini = NumericTraits<double>::max();
    IndexIterator bestSplit;
    for(int k=0; k<mtry; ++k)
    {
        sorter.setColumn(splitColumns[k]);
        std::sort(indices, indices+exampleCount, sorter);

        // initially, all examples belong to the right partition
        currentCounts[0].init(0);
        std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
                       currentCounts[1].begin(), std::multiplies<double>());
        totalCounts[0] = 0;
        totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
        for(int m = 0; m < exampleCount-1; ++m)
        {
            // move example m from the right into the left partition
            int label = labels[indices[m]];
            double w = classWeights[label];
            currentCounts[0][label] += w;
            totalCounts[0] += w;
            currentCounts[1][label] -= w;
            totalCounts[1] -= w;

            // a split is only possible between distinct feature values
            if (m < exampleCount-2 &&
                features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
                continue;

            double gini = 0.0;
            if(classCount == 2)
            {
                gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
                       currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
            }
            else
            {
                for(int l=0; l<classCount; ++l)
                    gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
                            currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
            }
            if(gini < minGini)
            {
                minGini = gini;
                bestSplit = indices+m;
                bestSplitColumn = splitColumns[k];
                bestCounts[0] = currentCounts[0];
                bestCounts[1] = currentCounts[1];
            }
        }
    }

    // partition the examples according to the best feature
    sorter.setColumn(bestSplitColumn);
    std::sort(indices, indices+exampleCount, sorter);

    for(int k=0; k<2; ++k)
    {
        bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
    }

    // place the threshold halfway between the two examples adjacent to the split
    threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
    ++bestSplit;

    // a partition is pure if it contains only a single class
    counter.reset();
    std::for_each(indices, bestSplit, counter);
    pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeCountNonzeroFunctor());
    counter.reset();
    std::for_each(bestSplit, indices+exampleCount, counter);
    pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeCountNonzeroFunctor());

    return bestSplit;
}
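/*  Explanatory note on the split criterion (added for clarity, derived from
    the code above): for a candidate split, let \f$ c_{s,l} \f$ be the
    (class-weighted) count of examples with label \f$ l \f$ in partition
    \f$ s \f$ (0: left, 1: right), and \f$ t_s = \sum_l c_{s,l} \f$ the
    partition total. findBestSplit() minimizes the total Gini impurity
    \f[
        \sum_{s=0}^{1} \sum_{l} c_{s,l} \left( 1 - \frac{c_{s,l}}{t_s} \right),
    \f]
    which for two classes equals \f$ 2\, c_{s,0}\, c_{s,1} / t_s \f$ summed
    over both partitions; the two-class branch above drops the constant
    factor 2, which does not change the minimizer.
*/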
enum { DecisionTreeNoParent = -1 };

template <class Iterator>
struct DecisionTreeStackEntry
{
    DecisionTreeStackEntry(Iterator i, int c,
                           int lp = DecisionTreeNoParent, int rp = DecisionTreeNoParent)
    : indices(i), exampleCount(c),
      leftParent(lp), rightParent(rp)
    {}

    Iterator indices;
    int exampleCount, leftParent, rightParent;
};

class DecisionTree
{
  public:
    typedef Int32 TreeInt;
    ArrayVector<TreeInt> tree_;
    ArrayVector<double> terminalWeights_;
    unsigned int classCount_;
    DecisionTreeAxisSplitFunctor split;

  public:

    DecisionTree(unsigned int classCount)
    : classCount_(classCount)
    {}

    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        tree_.clear();
        terminalWeights_.clear();
    }

    template <class U, class C, class Iterator, class Options, class Random>
    void learn(MultiArrayView<2, U, C> const & features,
               ArrayVector<int> const & labels,
               Iterator indices, int exampleCount,
               Options const & options,
               Random & randint);

    // descend the tree and return an iterator to the class weights
    // of the leaf the given example falls into
    template <class U, class C>
    ArrayVector<double>::const_iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            // non-positive child indices encode leaves: -nodeindex is the
            // position of the leaf's class weights in terminalWeights_
            if(nodeindex <= 0)
                return terminalWeights_.begin() + (-nodeindex);
        }
    }

    template <class U, class C>
    int
    predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights+classCount_) - weights;
    }

    // like predict(), but return the leaf's index in the weight array
    template <class U, class C>
    int
    leafID(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            if(nodeindex <= 0)
                return -nodeindex;
        }
    }

    void depth(int & maxDep, int & interiorCount, int & leafCount, int k = 0, int d = 1) const
    {
        DecisionTreeNodeProxy<TreeInt> node(tree_, k);
        ++interiorCount;
        ++d;
        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child > 0)
                depth(maxDep, interiorCount, leafCount, child, d);
            else
            {
                ++leafCount;
                if(maxDep < d)
                    maxDep = d;
            }
        }
    }

    void printStatistics(std::ostream & o) const
    {
        int maxDep = 0, interiorCount = 0, leafCount = 0;
        depth(maxDep, interiorCount, leafCount);

        o << "interior nodes: " << interiorCount <<
             ", terminal nodes: " << leafCount <<
             ", depth: " << maxDep << "\n";
    }

    void print(std::ostream & o, int k = 0, std::string s = "") const
    {
        DecisionTreeNodeProxy<TreeInt> node(tree_, k);
        o << s << (*node.decisionColumns()) << " " << terminalWeights_[node.decisionWeightsIndex()] << "\n";

        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child <= 0)
                o << s << " weights " << terminalWeights_[-child] << " "
                  << terminalWeights_[-child+1] << "\n";
            else
                print(o, child, s+" ");
        }
    }
};


template <class U, class C, class Iterator, class Options, class Random>
void DecisionTree::learn(MultiArrayView<2, U, C> const & features,
                         ArrayVector<int> const & labels,
                         Iterator indices, int exampleCount,
                         Options const & options,
                         Random & randint)
{
    ArrayVector<double> const & classLoss = options.class_weights;

    vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
        "DecisionTree::learn(): class weights array has wrong size.");

    reset();

    unsigned int mtry = options.mtry;
    MultiArrayIndex cols = columnCount(features);

    split.init(mtry, cols, classCount_, classLoss);

    // depth-first construction of the tree using an explicit stack
    typedef DecisionTreeStackEntry<Iterator> Entry;
    ArrayVector<Entry> stack;
    stack.push_back(Entry(indices, exampleCount));

    while(!stack.empty())
    {
        indices = stack.back().indices;
        exampleCount = stack.back().exampleCount;
        int leftParent = stack.back().leftParent,
            rightParent = stack.back().rightParent;

        stack.pop_back();

        Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);

        int currentNode = split.writeSplitParameters(tree_, terminalWeights_);

        // link the new node into its parent
        if(leftParent != DecisionTreeNoParent)
            DecisionTreeNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
        if(rightParent != DecisionTreeNoParent)
            DecisionTreeNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
        leftParent = currentNode;
        rightParent = DecisionTreeNoParent;

        for(int l=0; l<2; ++l)
        {
            if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
            {
                // sample is still large enough and not yet perfectly separated => split
                stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
            }
            else
            {
                // create a terminal node that stores the class weights
                DecisionTreeNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();

                split.writeWeights(l, terminalWeights_);
            }
            std::swap(leftParent, rightParent);
            indices = bestSplit;
        }
    }
}

} // namespace detail

class RandomForestOptions
{
  public:
        /** Initialize all options with default values.
        */
    RandomForestOptions()
    : training_set_proportion(1.0),
      mtry(0),
      min_split_node_size(1),
      training_set_size(0),
      sample_with_replacement(true),
      sample_classes_individually(false),
      treeCount(255)
    {}

        /** Number of features considered in each node.

            If \a n is 0 (the default), the number of features tried in every node
            is determined by the square root of the total number of features.
            According to Breiman, this quantity should always be optimized by means
            of the out-of-bag error.<br>
            Default: 0 (use <tt>sqrt(columnCount(featureMatrix))</tt>)
        */
    RandomForestOptions & featuresPerNode(unsigned int n)
    {
        mtry = n;
        return *this;
    }

        /** How to sample the subset of the training data for each tree.

            Each tree is only trained with a subset of the entire training data.
            If \a r is <tt>true</tt>, this subset is sampled from the entire training set with
            replacement.<br>
            Default: <tt>true</tt> (use sampling with replacement)
        */
    RandomForestOptions & sampleWithReplacement(bool r)
    {
        sample_with_replacement = r;
        return *this;
    }

    RandomForestOptions & setTreeCount(unsigned int cnt)
    {
        treeCount = cnt;
        return *this;
    }

        /** Proportion of training examples used for each tree.

            If \a p is 1.0 (the default), and samples are drawn with replacement,
            the training set of each tree will contain as many examples as the entire
            training set, but some are drawn multiple times and others not at all. On average,
            each tree is actually trained on about 65% of the examples in the full
            training set. Changing the proportion mainly makes sense when
            sampleWithReplacement() is set to <tt>false</tt>. trainingSetSizeProportional() is
            overridden by trainingSetSizeAbsolute().<br>
            Default: 1.0
        */
    RandomForestOptions & trainingSetSizeProportional(double p)
    {
        vigra_precondition(p >= 0.0 && p <= 1.0,
            "RandomForestOptions::trainingSetSizeProportional(): proportion must be in [0, 1].");
        if(training_set_size == 0) // otherwise, absolute size gets priority
            training_set_proportion = p;
        return *this;
    }

        /** Size of the training set for each tree.

            If this option is set, it overrides the proportion set by
            trainingSetSizeProportional(). When classes are sampled individually,
            the number of examples is divided by the number of classes (rounded upwards)
            to determine the number of examples drawn from every class.<br>
            Default: <tt>0</tt> (determine size by proportion)
        */
    RandomForestOptions & trainingSetSizeAbsolute(unsigned int s)
    {
        training_set_size = s;
        if(s > 0)
            training_set_proportion = 0.0;
        return *this;
    }

        /** Are the classes sampled individually?

            If \a s is <tt>false</tt> (the default), the training set for each tree is sampled
            without considering class labels. Otherwise, samples are drawn from each
            class independently. The latter is especially useful in connection
            with the specification of an absolute training set size: then, the same number of
            examples is drawn from every class. This can be used as a counter-measure when the
            classes are very unbalanced in size.<br>
            Default: <tt>false</tt>
        */
    RandomForestOptions & sampleClassesIndividually(bool s)
    {
        sample_classes_individually = s;
        return *this;
    }

        /** Number of examples required for a node to be split.

            When the number of examples in a node is below this number, the node is not
            split even if class separation is not yet perfect. Instead, the node returns
            the proportion of each class (among the remaining examples) during the
            prediction phase.<br>
            Default: 1 (complete growing)
        */
    RandomForestOptions & minSplitNodeSize(unsigned int n)
    {
        if(n == 0)
            n = 1;
        min_split_node_size = n;
        return *this;
    }

        /** Use a weighted random forest.

            This is usually used to penalize the errors for the minority class.
            Weights must be convertible to <tt>double</tt>, and the array of weights
            must contain as many entries as there are classes.<br>
            Default: do not use weights
        */
    template <class WeightIterator>
    RandomForestOptions & weights(WeightIterator weights, unsigned int classCount)
    {
        class_weights.clear();
        if(weights != 0)
            class_weights.insert(weights, classCount);
        return *this;
    }

    RandomForestOptions & oobData(MultiArrayView<2, UInt8> & data)
    {
        oob_data = data;
        return *this;
    }

    MultiArrayView<2, UInt8> oob_data;
    ArrayVector<double> class_weights;
    double training_set_proportion;
    unsigned int mtry, min_split_node_size, training_set_size;
    bool sample_with_replacement, sample_classes_individually;
    unsigned int treeCount;
};
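/*  A minimal usage sketch for the chainable option setters above; the
    concrete values are made up for illustration only:

        vigra::RandomForestOptions options = vigra::RandomForestOptions()
                .featuresPerNode(3)                // try 3 random features per node
                .sampleWithReplacement(false)      // draw samples without replacement
                .trainingSetSizeProportional(0.8)  // use 80% of the data for each tree
                .sampleClassesIndividually(true);  // sample each class separately
*/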
/*****************************************************************/
/*                                                               */
/*                        RandomForest                           */
/*                                                               */
/*****************************************************************/

template <class ClassLabelType>
class RandomForest
{
  public:
    ArrayVector<ClassLabelType> classes_;
    ArrayVector<detail::DecisionTree> trees_;
    MultiArrayIndex columnCount_;
    RandomForestOptions options_;

  public:

    // The first two constructors are straightforward: they take either
    // iterators to an array of class labels or the two label values directly.
    template<class ClassLabelIterator>
    RandomForest(ClassLabelIterator cl, ClassLabelIterator cend,
                 unsigned int treeCount = 255,
                 RandomForestOptions const & options = RandomForestOptions())
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTree(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptions: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptions::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptions::weights(): wrong number of classes.");
    }

    RandomForest(ClassLabelType const & c1, ClassLabelType const & c2,
                 unsigned int treeCount = 255,
                 RandomForestOptions const & options = RandomForestOptions())
    : classes_(2),
      trees_(treeCount, detail::DecisionTree(2)),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
            "RandomForestOptions::weights(): wrong number of classes.");
        classes_[0] = c1;
        classes_[1] = c2;
    }

    // This constructor is intended especially for the CrossValidator class:
    // the tree count is taken from the options object.
    template<class ClassLabelIterator>
    RandomForest(ClassLabelIterator cl, ClassLabelIterator cend,
                 RandomForestOptions const & options)
    : classes_(cl, cend),
      trees_(options.treeCount, detail::DecisionTree(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptions: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptions::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptions::weights(): wrong number of classes.");
    }

    // Reconstruct a forest from pre-computed trees and terminal weights.
    // This constructor does not use an options object; the feature count
    // is passed explicitly via columnCount.
    template<class ClassLabelIterator, class TreeIterator, class WeightIterator>
    RandomForest(ClassLabelIterator cl, ClassLabelIterator cend,
                 unsigned int treeCount, unsigned int columnCount,
                 TreeIterator trees, WeightIterator weights)
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTree(classes_.size())),
      columnCount_(columnCount)
    {
        for(unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].tree_ = *trees;
            trees_[k].terminalWeights_ = *weights;
        }
    }

    int featureCount() const
    {
        vigra_precondition(columnCount_ > 0,
            "RandomForest::featureCount(): Random forest has not been trained yet.");
        return columnCount_;
    }

    int labelCount() const
    {
        return classes_.size();
    }

    int treeCount() const
    {
        return trees_.size();
    }

    // an empty class_weights array means: use an unweighted random forest
    template <class U, class C, class Array, class Random>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels,
                 Random const & random);

    template <class U, class C, class Array>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels)
    {
        return learn(features, labels, RandomTT800::global());
    }

    template <class U, class C>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features) const;

    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = predictLabel(rowVector(features, k));
    }

    template <class U, class C, class Iterator>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                Iterator priors) const;

    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const;

    template <class U, class C1, class T, class C2>
    void predictNodes(MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, T, C2> & NodeIDs) const;
};
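/*  A minimal end-to-end sketch (illustration only; the data setup is
    hypothetical). Because of the #define/#undef pair at the top and bottom
    of this file, the public class name is RandomForestDeprec:

        // 200 training examples with 10 features each, labels 0 and 1
        vigra::Matrix<double> features(200, 10);
        vigra::ArrayVector<int> labels(200);
        // ... fill features and labels ...

        vigra::RandomForestDeprec<int> rf(0, 1, 255);  // two classes, 255 trees
        double oobError = rf.learn(features, labels);  // returns out-of-bag error

        vigra::Matrix<double> example(1, 10);          // one example per row
        // ... fill example ...
        int predicted = rf.predictLabel(example);
*/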
template <class ClassLabelType>
template <class U, class C1, class Array, class Random>
double
RandomForest<ClassLabelType>::learn(MultiArrayView<2, U, C1> const & features,
                                    Array const & labels,
                                    Random const & random)
{
    unsigned int classCount = classes_.size();
    unsigned int m = rowCount(features);
    unsigned int n = columnCount(features);
    vigra_precondition((unsigned int)(m) == (unsigned int)labels.size(),
        "RandomForest::learn(): Label array has wrong size.");

    vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
        "RandomForest::learn(): Requested training set size exceeds total number of examples.");

    MultiArrayIndex mtry = (options_.mtry == 0)
                               ? int(std::floor(std::sqrt(double(n)) + 0.5))
                               : options_.mtry;

    vigra_precondition(mtry <= (MultiArrayIndex)n,
        "RandomForest::learn(): mtry must not exceed the number of features.");

    MultiArrayIndex msamples = options_.training_set_size;
    if(options_.sample_classes_individually)
        msamples = int(std::ceil(double(msamples) / classCount));

    ArrayVector<int> intLabels(m), classExampleCounts(classCount);

    // verify the input labels and translate them to integer indices
    int minClassCount;
    {
        typedef std::map<ClassLabelType, int> LabelChecker;
        typedef typename LabelChecker::iterator LabelCheckerIterator;
        LabelChecker labelChecker;
        for(unsigned int k=0; k<classCount; ++k)
            labelChecker[classes_[k]] = k;

        for(unsigned int k=0; k<m; ++k)
        {
            LabelCheckerIterator found = labelChecker.find(labels[k]);
            vigra_precondition(found != labelChecker.end(),
                "RandomForest::learn(): Unknown class label encountered.");
            intLabels[k] = found->second;
            ++classExampleCounts[intLabels[k]];
        }
        minClassCount = *argMin(classExampleCounts.begin(), classExampleCounts.end());
        vigra_precondition(minClassCount > 0,
            "RandomForest::learn(): At least one class is missing in the training set.");
        if(msamples > 0 && options_.sample_classes_individually &&
           !options_.sample_with_replacement)
        {
            vigra_precondition(msamples <= minClassCount,
                "RandomForest::learn(): Too few examples in smallest class to reach "
                "requested training set size.");
        }
    }
    columnCount_ = n;
    ArrayVector<int> indices(m);
    for(unsigned int k=0; k<m; ++k)
        indices[k] = k;

    if(options_.sample_classes_individually)
    {
        detail::RandomForestLabelSorter<ArrayVector<int> > sorter(intLabels);
        std::sort(indices.begin(), indices.end(), sorter);
    }

    ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);

    UniformIntRandomFunctor<Random> randint(0, m-1, random);

    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        // draw the training set for the current tree
        ArrayVector<int> trainingSet;
        usedIndices.init(0);

        if(options_.sample_classes_individually)
        {
            int first = 0;
            for(unsigned int l=0; l<classCount; ++l)
            {
                int lc = classExampleCounts[l];
                int lsamples = (msamples == 0)
                                   ? int(std::ceil(options_.training_set_proportion*lc))
                                   : msamples;

                if(options_.sample_with_replacement)
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        trainingSet.push_back(indices[first+randint(lc)]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                else
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
                        trainingSet.push_back(indices[first+ll]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                first += lc;
            }
        }
        else
        {
            if(msamples == 0)
                msamples = int(std::ceil(options_.training_set_proportion*m));

            if(options_.sample_with_replacement)
            {
                for(int l=0; l<msamples; ++l)
                {
                    trainingSet.push_back(indices[randint(m)]);
                    ++usedIndices[trainingSet.back()];
                }
            }
            else
            {
                for(int l=0; l<msamples; ++l)
                {
                    std::swap(indices[l], indices[l+randint(m-l)]);
                    trainingSet.push_back(indices[l]);
                    ++usedIndices[trainingSet.back()];
                }
            }
        }
        trees_[k].learn(features, intLabels,
                        trainingSet.begin(), trainingSet.size(),
                        options_.featuresPerNode(mtry), randint);

        // update the out-of-bag statistics with the examples that were
        // not used for training the current tree
        for(unsigned int l=0; l<m; ++l)
        {
            if(!usedIndices[l])
            {
                ++oobCount[l];
                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
                {
                    ++oobErrorCount[l];
                    if(options_.oob_data.data() != 0)
                        options_.oob_data(l, k) = 2;
                }
                else if(options_.oob_data.data() != 0)
                {
                    options_.oob_data(l, k) = 1;
                }
            }
        }
        // TODO: default value for oob_data
        // TODO: implement variable importance
#ifdef VIGRA_RF_VERBOSE
        trees_[k].printStatistics(std::cerr);
#endif
    }
    // average the error rates of all examples that were
    // out-of-bag at least once
    double oobError = 0.0;
    int totalOobCount = 0;
    for(unsigned int l=0; l<m; ++l)
        if(oobCount[l])
        {
            oobError += double(oobErrorCount[l]) / oobCount[l];
            ++totalOobCount;
        }
    return oobError / totalOobCount;
}

template <class ClassLabelType>
template <class U, class C>
ClassLabelType
RandomForest<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    return classes_[argMax(prob)];
}


// Same as above, but weight the class probabilities with the given priors
// before choosing the most probable label.
template <class ClassLabelType>
template <class U, class C, class Iterator>
ClassLabelType
RandomForest<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features,
                                           Iterator priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
    return classes_[argMax(prob)];
}

template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForest<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> const & features,
                                                   MultiArrayView<2, T, C2> & prob) const
{
    // features is n x p; prob is n x labelCount() and receives, for each
    // example, the probability of belonging to each class

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictProbabilities(): Feature matrix and probability matrix size mismatch.");

    // the feature matrix may contain more columns than were used during
    // training; the extra columns are simply ignored
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForest::predictProbabilities(): Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob) == (MultiArrayIndex)labelCount(),
        "RandomForest::predictProbabilities(): Probability matrix must have as many columns as there are classes.");

    // classify each row
    for(int row=0; row < rowCount(features); ++row)
    {
        // per-class weights returned by a single tree: the class proportions
        // of the leaf the example falls into (or the raw class-weighted
        // counts for a weighted forest)
        ArrayVector<double>::const_iterator weights;

        // sum of all votes, accumulated over all trees
        double totalWeight = 0.0;

        // initialize the vote counts; prob(row, l) accumulates the votes
        // until the final normalization below
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) = 0.0;

        // let each tree cast its votes
        for(unsigned int k=0; k<trees_.size(); ++k)
        {
            // get the weights predicted by a single tree
            weights = trees_[k].predict(rowVector(features, row));

            // update the vote counts
            for(unsigned int l=0; l<classes_.size(); ++l)
            {
                prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
                totalWeight += weights[l];
            }
        }

        // normalize the votes in each row by the total vote count
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
    }
}
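/*  A minimal sketch of querying class probabilities for a whole test set
    (illustration only, assuming a trained two-class forest rf as in the
    earlier example):

        vigra::Matrix<double> testFeatures(50, 10);
        // ... fill testFeatures ...
        vigra::Matrix<double> prob(50, 2);   // one column per class
        rf.predictProbabilities(testFeatures, prob);
        // prob(row, l) now holds the forest's probability that example
        // 'row' belongs to class l; each row sums to 1
*/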
template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForest<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> const & features,
                                           MultiArrayView<2, T, C2> & NodeIDs) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForest::predictNodes(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) <= rowCount(NodeIDs),
        "RandomForest::predictNodes(): Too few rows in NodeIDs matrix.");
    vigra_precondition(columnCount(NodeIDs) >= treeCount(),
        "RandomForest::predictNodes(): Too few columns in NodeIDs matrix.");
    NodeIDs.init(0);
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        for(int row=0; row < rowCount(features); ++row)
        {
            NodeIDs(row, k) = trees_[k].leafID(rowVector(features, row));
        }
    }
}

//@}

} // namespace vigra

#undef RandomForest
#undef DecisionTree

#endif // VIGRA_RANDOM_FOREST_HXX